Sssplendid commited on
Commit
1ece77f
·
verified ·
1 Parent(s): ae93aaa

Add gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +9 -0
  2. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/LICENSE +21 -0
  3. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/README.md +519 -0
  4. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/.metadata +3 -0
  5. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__0_0.distcp +3 -0
  6. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__1_0.distcp +3 -0
  7. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__2_0.distcp +3 -0
  8. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__3_0.distcp +3 -0
  9. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__4_0.distcp +3 -0
  10. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__5_0.distcp +3 -0
  11. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__6_0.distcp +3 -0
  12. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__7_0.distcp +3 -0
  13. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/config.json +36 -0
  14. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/blt_transformer_1000hash.json +98 -0
  15. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/blt_transformer_1_5B.json +99 -0
  16. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/blt_transformer_380M.json +98 -0
  17. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/delta_net_1B.json +29 -0
  18. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/delta_net_340M.json +26 -0
  19. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/gated_deltanet_1B.json +22 -0
  20. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/gated_deltanet_340M.json +22 -0
  21. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/gated_deltanet_h_340M.json +28 -0
  22. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/gla_340M.json +24 -0
  23. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/gla_7B.json +25 -0
  24. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/gsa_340M.json +29 -0
  25. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/mergenet_340M.json +34 -0
  26. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/mergenet_64M.json +34 -0
  27. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/qwen3_next_1B.json +44 -0
  28. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/qwen3_next_350M.json +44 -0
  29. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/transformer_1B.json +22 -0
  30. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/transformer_340M.json +18 -0
  31. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/transformer_7B.json +21 -0
  32. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__init__.py +1 -0
  33. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/__init__.cpython-310.pyc +0 -0
  34. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/__init__.cpython-311.pyc +0 -0
  35. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/__init__.cpython-313.pyc +0 -0
  36. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/config_manager.cpython-310.pyc +0 -0
  37. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/config_manager.cpython-311.pyc +0 -0
  38. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/data.cpython-310.pyc +0 -0
  39. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/data.cpython-311.pyc +0 -0
  40. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/train.cpython-310.pyc +0 -0
  41. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/train.cpython-311.pyc +0 -0
  42. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/train.cpython-313.pyc +0 -0
  43. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/c4_test.py +603 -0
  44. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/components/__init__.py +0 -0
  45. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/components/__pycache__/__init__.cpython-310.pyc +0 -0
  46. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/components/__pycache__/__init__.cpython-311.pyc +0 -0
  47. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/components/__pycache__/checkpoint.cpython-310.pyc +0 -0
  48. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/components/__pycache__/checkpoint.cpython-311.pyc +0 -0
  49. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/components/checkpoint.py +59 -0
  50. gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/config_manager.py +981 -0
.gitattributes CHANGED
@@ -321,3 +321,12 @@ gated_deltanet_1b_v3/gated_deltanet_1b_mars_shampoo_lr2e_2_b1_0_95_b2_0_99_eps_1
321
  gated_deltanet_1b_v3/gated_deltanet_1b_mars_shampoo_lr2e_2_b1_0_95_b2_0_99_eps_1e_15_scale2_0_rank512_20260506_140829/exp_data/checkpoint/step-30720/__5_0.distcp filter=lfs diff=lfs merge=lfs -text
322
  gated_deltanet_1b_v3/gated_deltanet_1b_mars_shampoo_lr2e_2_b1_0_95_b2_0_99_eps_1e_15_scale2_0_rank512_20260506_140829/exp_data/checkpoint/step-30720/__6_0.distcp filter=lfs diff=lfs merge=lfs -text
323
  gated_deltanet_1b_v3/gated_deltanet_1b_mars_shampoo_lr2e_2_b1_0_95_b2_0_99_eps_1e_15_scale2_0_rank512_20260506_140829/exp_data/checkpoint/step-30720/__7_0.distcp filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
321
  gated_deltanet_1b_v3/gated_deltanet_1b_mars_shampoo_lr2e_2_b1_0_95_b2_0_99_eps_1e_15_scale2_0_rank512_20260506_140829/exp_data/checkpoint/step-30720/__5_0.distcp filter=lfs diff=lfs merge=lfs -text
322
  gated_deltanet_1b_v3/gated_deltanet_1b_mars_shampoo_lr2e_2_b1_0_95_b2_0_99_eps_1e_15_scale2_0_rank512_20260506_140829/exp_data/checkpoint/step-30720/__6_0.distcp filter=lfs diff=lfs merge=lfs -text
323
  gated_deltanet_1b_v3/gated_deltanet_1b_mars_shampoo_lr2e_2_b1_0_95_b2_0_99_eps_1e_15_scale2_0_rank512_20260506_140829/exp_data/checkpoint/step-30720/__7_0.distcp filter=lfs diff=lfs merge=lfs -text
324
+ gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/.metadata filter=lfs diff=lfs merge=lfs -text
325
+ gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__0_0.distcp filter=lfs diff=lfs merge=lfs -text
326
+ gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__1_0.distcp filter=lfs diff=lfs merge=lfs -text
327
+ gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__2_0.distcp filter=lfs diff=lfs merge=lfs -text
328
+ gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__3_0.distcp filter=lfs diff=lfs merge=lfs -text
329
+ gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__4_0.distcp filter=lfs diff=lfs merge=lfs -text
330
+ gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__5_0.distcp filter=lfs diff=lfs merge=lfs -text
331
+ gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__6_0.distcp filter=lfs diff=lfs merge=lfs -text
332
+ gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__7_0.distcp filter=lfs diff=lfs merge=lfs -text
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023-2025 Songlin Yang, Yu Zhang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/README.md ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ # 🔥 Flame: Flash Language Modeling Made Easy
4
+
5
+ [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/fla-org/flame)
6
+
7
+ </div>
8
+
9
+ Welcome to 🔥 `flame`, a minimal and efficient framework built on `torchtitan` for language models with blazing efficiency.
10
+
11
+ **Feature Highlights:**
12
+
13
+ - 🚀 Minimal, easy-to-use, extensible training framework
14
+ - 🤗 Seamless integration with `fla` and `transformers`
15
+ - 🔄 Zero-cost data preprocessing: online tokenization, dataset shuffling, and multiple datasets support
16
+ - 🔮 4D parallelism (coming soon)
17
+
18
+ ## Setup
19
+
20
+ To get started, clone the `flame` repository and install the required dependencies:
21
+
22
+ ```bash
23
+ git clone https://github.com/fla-org/flame.git
24
+ cd flame
25
+ pip install .
26
+ ```
27
+
28
+ Install the latest version of fla
29
+ ```
30
+ pip uninstall flash-linear-attention && pip install -U --no-use-pep517 git+https://github.com/fla-org/flash-linear-attention
31
+ ```
32
+
33
+ [Important] Install specific version of torchtitan
34
+ ```
35
+ pip install git+https://github.com/pytorch/torchtitan.git@0b44d4c
36
+ ```
37
+
38
+
39
+ ## Dataset Preparation
40
+ To download the dataset to your local disk, create a new Python file with the following content and execute it:
41
+
42
+ ```py
43
+ from datasets import load_dataset
44
+
45
+ # load fineweb-edu with parallel processing
46
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", num_proc=64, cache_dir="/your/cache/path")
47
+
48
+ # or load a subset with roughly 100B tokens, suitable for small- or medium-sized experiments
49
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_proc=64, cache_dir="/your/cache/path")
50
+ ```
51
+
52
+ ## Training Recipes
53
+
54
+ Here's an example of training a 340M FLA Transformer model with a LLaMA-like architecture from scratch on a 100BT subset of the Fineweb-edu corpus ~~in streaming mode~~. (Do not use streaming mode if you are concerned about resuming training.)
55
+
56
+ > [!WARNING]
57
+ > If the dataset is not downloaded beforehand, the streaming mode will attempt to fetch it from a remote server and download it on-the-fly, which can be highly unstable during training due to network issues.
58
+ > For stable training, ensure the dataset is downloaded locally (see [**Dataset Preparation**](#dataset-preparation)). Otherwise, we assume you are only testing the new corpus.
59
+
60
+ ```sh
61
+ bash train.sh \
62
+ --job.config_file flame/models/fla.toml \
63
+ --job.dump_folder exp/transformer-340M-4K-10B/batch1.seqlen65536.context4096.warmup1024.update1.steps20480.lr1e-3.cosine \
64
+ --model.config configs/transformer_340M.json \
65
+ --model.tokenizer_path fla-hub/transformer-1.3B-100B \
66
+ --optimizer.name AdamW \
67
+ --optimizer.eps 1e-15 \
68
+ --optimizer.lr 1e-3 \
69
+ --lr_scheduler.warmup_steps 1024 \
70
+ --lr_scheduler.lr_min 0.1 \
71
+ --lr_scheduler.decay_type cosine \
72
+ --training.batch_size 1 \
73
+ --training.seq_len 65536 \
74
+ --training.context_len 4096 \
75
+ --training.varlen \
76
+ --training.gradient_accumulation_steps 1 \
77
+ --training.steps 20480 \
78
+ --training.max_norm 1.0 \
79
+ --training.skip_nan_inf \
80
+ --training.dataset HuggingFaceFW/fineweb-edu \
81
+ --training.dataset_name sample-100BT \
82
+ --training.dataset_split train \
83
+ --training.num_workers 32 \
84
+ --training.prefetch_factor 2 \
85
+ --training.seed 42 \
86
+ --training.compile \
87
+ --checkpoint.interval 2048 \
88
+ --checkpoint.load_step -1 \
89
+ --checkpoint.keep_latest_k 2 \
90
+ --metrics.log_freq 1
91
+ ```
92
+
93
+ You can specify the number of GPUs by setting the environment variable `NGPU`, which defaults to 8.
94
+ **For single-GPU debugging, set `NGPU=1`.**
95
+
96
+ We provide several [config files](https://github.com/fla-org/flame/tree/main/configs) for different models.
97
+ By default, the learning rate is set to 1e-3 with a cosine scheduler. Other schedulers, such as WSD (wsd), are also supported.
98
+
99
+ **Key parameters:**
100
+ - `--lr_scheduler.decay_ratio`: The proportion of the steps allocated to the decay phase. The learning rate will remain stable after the warmup period and only start decaying during the last `decay_ratio` portion of the total training steps, which is known as the Warmup-Stable-Decay (WSD) schedule.
101
+ - `--lr_scheduler.warmup_steps`: The number of steps for the learning rate warmup phase.
102
+ - `--training.steps`: Total number of training steps.
103
+ - `--training.batch_size`: Batch size per device, must be 1 if `--training.varlen` is set.
104
+ - `--training.seq_len`: The length of each sequence in the batch, which is concatenated from multiple samples.
105
+ - `--training.context_len`: The max allowed length of a sample. For non-varlen mode, this is equivalent to `seq_len`.
106
+ - `--training.varlen`: Whether to conduct variable-length sequence training.
107
+ - `--training.gradient_accumulation_steps`: Number of gradient accumulation steps.
108
+
109
+ > [!WARNING]
110
+ > The total number of tokens processed per batch, referred to as `global_batch_size`, is calculated as batch_size × gradient_accumulation_steps × num_gpus.
111
+ > Each step processes `global_batch_size * seq_len` tokens.
112
+ > Monitor the value of `global_batch_size`, `warmup_steps`, and `steps` carefully when modifying any of the hyperparameters!
113
+
114
+ For a detailed explanation of all parameters, run:
115
+
116
+ ```sh
117
+ bash train.sh -h
118
+ ```
119
+
120
+ <details>
121
+ <summary>Usage</summary>
122
+
123
+ ```py
124
+ options:
125
+ -h, --help show this help message and exit
126
+ --job.config_file JOB.CONFIG_FILE
127
+ Job config file
128
+ --job.dump_folder JOB.DUMP_FOLDER
129
+ Folder to dump job outputs
130
+ --job.description JOB.DESCRIPTION
131
+ Description of the job
132
+ --job.use_for_integration_test
133
+ Add this config to the integration test suite
134
+ --job.print_args Print the args to terminal
135
+ --model.config MODEL.CONFIG
136
+ Path to the model config
137
+ --model.norm_type MODEL.NORM_TYPE
138
+ Type of layer normalization to use [layernorm,
139
+ np_layernorm, rmsnorm, fused_rmsnorm]
140
+ --model.tokenizer_path MODEL.TOKENIZER_PATH
141
+ Tokenizer path
142
+ --profiling.enable_profiling
143
+ Whether to enable pytorch profiler
144
+ --profiling.save_traces_folder PROFILING.SAVE_TRACES_FOLDER
145
+ Trace files location
146
+ --profiling.profile_freq PROFILING.PROFILE_FREQ
147
+ How often to collect profiler traces, in iterations
148
+ --profiling.enable_memory_snapshot
149
+ Whether to dump memory snapshot
150
+ --profiling.save_memory_snapshot_folder PROFILING.SAVE_MEMORY_SNAPSHOT_FOLDER
151
+ Memeory snapshot files location
152
+ --optimizer.name OPTIMIZER.NAME
153
+ Optimizer to use
154
+ --optimizer.eps OPTIMIZER.EPS
155
+ Epsilon value for the optimizer.
156
+ --optimizer.fused Whether the fused implementation(CUDA only) is used.
157
+ --optimizer.scheduler {wsd,cosine,linear}
158
+ Scheduler to use. Currently supported: wsd, cosine,
159
+ and linear.
160
+ --optimizer.lr OPTIMIZER.LR
161
+ Learning rate to use
162
+ --optimizer.min_lr_ratio OPTIMIZER.MIN_LR_RATIO
163
+ Min lr ratio for lr scheduler
164
+ --optimizer.early_step_in_backward
165
+ Whether to apply optimizer in the backward. Caution,
166
+ optimizer_in_backward is not compatible with gradients
167
+ clipping, users should not call
168
+ register_post_accumulate_grad_hook after the optimizer
169
+ is built.
170
+ --training.batch_size TRAINING.BATCH_SIZE
171
+ Batch size
172
+ --training.seq_len TRAINING.SEQ_LEN
173
+ Sequence length
174
+ --training.context_len TRAINING.CONTEXT_LEN
175
+ Max length allowed for each sequence
176
+ --training.varlen Whether to take sequences of variable length as input
177
+ --training.warmup_steps TRAINING.WARMUP_STEPS
178
+ Steps for lr scheduler warmup, normally 1/5 of
179
+ --training.steps
180
+ --training.gradient_accumulation_steps TRAINING.GRADIENT_ACCUMULATION_STEPS
181
+ Number of steps to accumulate gradients before
182
+ updating parameters
183
+ --training.steps TRAINING.STEPS
184
+ How many train steps to run
185
+ --training.max_norm TRAINING.MAX_NORM
186
+ Max norm for gradient clipping
187
+ --training.skip_nan_inf
188
+ Skip batch updates when NaN or INF gradients are
189
+ encountered during training
190
+ --training.dataset TRAINING.DATASET
191
+ Dataset to use, with comma separated values
192
+ --training.dataset_name TRAINING.DATASET_NAME
193
+ The name of the dataset config, with comma separated
194
+ values if provided
195
+ --training.dataset_split TRAINING.DATASET_SPLIT
196
+ Dataset split to use, with comma separated values if
197
+ provided
198
+ --training.data_dir TRAINING.DATA_DIR
199
+ Data dirs to use, with comma separated values if
200
+ provided
201
+ --training.data_files TRAINING.DATA_FILES
202
+ Data files to use, with comma separated values if
203
+ provided
204
+ --training.data_probs TRAINING.DATA_PROBS
205
+ Data sampling probabilities, with comma separated
206
+ values if provided
207
+ --training.streaming Whether to load dataset in streaming mode, used for
208
+ huge dataset
209
+ --training.num_workers TRAINING.NUM_WORKERS
210
+ Number of subprocesses to use for data loading. 0
211
+ means that the data will be loaded in the main
212
+ process.
213
+ --training.prefetch_factor TRAINING.PREFETCH_FACTOR
214
+ Number of batches loaded in advance by each worker.2
215
+ means there will be a total of 2 * num_workers batches
216
+ prefetched across all workers.
217
+ --training.data_parallel_replicate_degree TRAINING.DATA_PARALLEL_REPLICATE_DEGREE
218
+ The `data_parallel_replicate_degree` argument
219
+ specifies the degree of data parallelism for weight
220
+ replication. When this value is greater than 1,
221
+ weights will be replicated across
222
+ `data_parallel_replicate_degree` ranks. If
223
+ `data_parallel_shard_degree` is also greater than 1,
224
+ the parallelism method used is HSDP (Hybrid Sharded
225
+ Data Parallelism). Otherwise, the parallelism method
226
+ used is DDP (Distributed Data Parallelism). 1 means
227
+ disabled.
228
+ --training.data_parallel_shard_degree TRAINING.DATA_PARALLEL_SHARD_DEGREE
229
+ The `data_parallel_shard_degree` argument specifies
230
+ the degree of data parallelism for weight sharding.
231
+ When this value is greater than 1, weights will be
232
+ sharded across `data_parallel_shard_degree` ranks. If
233
+ `data_parallel_replicate_degree` is also greater than
234
+ 1, the parallelism method used is HSDP (Hybrid Sharded
235
+ Data Parallelism). Otherwise, the parallelism method
236
+ used is FSDP (Fully Sharded Data Parallelism). -1
237
+ means leftover ranks will be used (After
238
+ DP_REPLICATE/SP/PP). Note that only
239
+ `data_parallel_shard_degree` can be negative. 1 means
240
+ disabled.
241
+ --training.enable_cpu_offload
242
+ Whether to apply CPU offloading of parameters,
243
+ gradients, and optimizer states in FSDP
244
+ --training.tensor_parallel_degree TRAINING.TENSOR_PARALLEL_DEGREE
245
+ Tensor Parallelism degree. 1 means disabled.
246
+ --training.disable_loss_parallel
247
+ Whether to apply loss parallel when sequence parallel
248
+ is enabled
249
+ --training.mixed_precision_param {bfloat16,float32}
250
+ torch dtype to use for parameters when applying mixed
251
+ precision via FSDP. This feature only takes effect
252
+ when data_parallel_shard_degree > 1
253
+ --training.mixed_precision_reduce {float32}
254
+ torch dtype to use for reductions when applying mixed
255
+ precision via FSDP. This feature only takes effect
256
+ when data_parallel_shard_degree > 1
257
+ --training.compile Whether to compile the model
258
+ --training.gc_freq TRAINING.GC_FREQ
259
+ Python garbage control scheduling interval, in steps
260
+ --training.seed TRAINING.SEED
261
+ Choose the base RNG seed used for training
262
+ --training.deterministic
263
+ Use deterministic algorithms wherever possible, may be
264
+ slower
265
+ --metrics.log_freq METRICS.LOG_FREQ
266
+ How often to log metrics to TensorBoard, in iterations
267
+ --metrics.enable_tensorboard
268
+ Whether to log metrics to TensorBoard
269
+ --metrics.disable_color_printing
270
+ Whether to disable color printing in logs
271
+ --metrics.save_tb_folder METRICS.SAVE_TB_FOLDER
272
+ Folder to dump TensorBoard states
273
+ --metrics.rank_0_only
274
+ Whether to save TensorBoard metrics only for rank 0 or
275
+ for all ranks. When pipeline_parallel_degree is > 1,
276
+ this option uses the 0th rank of the last stage
277
+ pipeline group, which is the only stage that computes
278
+ loss metrics.
279
+ --metrics.enable_wandb
280
+ Whether to log metrics to Weights & Biases
281
+ --experimental.enable_async_tensor_parallel
282
+ Whether to apply async tensor parallel (currently only
283
+ effective when compile is enabled)
284
+ --experimental.pipeline_parallel_degree EXPERIMENTAL.PIPELINE_PARALLEL_DEGREE
285
+ Pipeline Parallelism degree, or number of ranks. 1
286
+ means disabled. If using looped schedules, this still
287
+ specifies the number of physical ranks, not the number
288
+ of stages. Stages per rank are inferred from split
289
+ points degree, and schedule.
290
+ --experimental.pipeline_parallel_split_points EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS [EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS ...]
291
+ Specify comma-separated names of modules to use as the
292
+ beginning of a split point. e.g. "layers.0,layers.2"
293
+ will cause the model to be split into 3 stages, the
294
+ first containing all the layers up to layers.0, the
295
+ second containing layers.0 and up to layers.2, the
296
+ third containing layers.2 and all the remaining
297
+ layers. Note: fully-automated splitting may be enabled
298
+ in the future, but currently the split points must be
299
+ specified manually.
300
+ --experimental.pipeline_parallel_schedule EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE
301
+ Specify the Pipeline Parallel schedule to use. The
302
+ supported schedules are: https://github.com/pytorch/py
303
+ torch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/to
304
+ rch/distributed/pipelining/schedules.py#L2161. The
305
+ schedule must be compatible with the split points and
306
+ stages_per_rank. Looped schedules (e.g.
307
+ Interleaved1F1B) require specifying
308
+ pipeline_parallel_degree = number of ranks, and
309
+ split_points = number of stages - 1
310
+ --experimental.pipeline_parallel_schedule_csv EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE_CSV
311
+ Specify the path to the pipeline parallel schedule csv
312
+ file to use. The pipeline_parallel_schedule argument
313
+ must be either PipelineScheduleSingle,
314
+ PipelineScheduleMulti, or _PipelineScheduleRuntime.
315
+ --experimental.pipeline_parallel_microbatches EXPERIMENTAL.PIPELINE_PARALLEL_MICROBATCHES
316
+ How many microbatches to split the global training
317
+ batch into when using pipeline parallelism. The global
318
+ training batch size must be evenly divisible by the
319
+ number of microbatches. The default value will be the
320
+ number of pipeline stages, if unspecified.
321
+ --experimental.enable_compiled_autograd
322
+ Enable CompiledAutograd to compile the backward.
323
+ --experimental.context_parallel_degree EXPERIMENTAL.CONTEXT_PARALLEL_DEGREE
324
+ Context parallelism degree. 1 means disabled.
325
+ --experimental.context_parallel_rotate_method EXPERIMENTAL.CONTEXT_PARALLEL_ROTATE_METHOD
326
+ The collective to use in context parallel SDPA for kv
327
+ shards exchange. 'allgather' means to all-gather all
328
+ kv shards on ranks after the first sub-SDPA
329
+ computation, 'alltoall' means to all-to-all shuffle
330
+ the kv shards. The default value is 'allgather'.
331
+ --checkpoint.enable_checkpoint
332
+ Whether to enable checkpoint
333
+ --checkpoint.folder CHECKPOINT.FOLDER
334
+ The folder to store the checkpoints. When
335
+ enable_checkpoint is set to true, checkpoints will be
336
+ in {--job.dump_folder}/{--checkpoint.folder}.
337
+ --checkpoint.interval_type CHECKPOINT.INTERVAL_TYPE
338
+ Checkpointing interval unit of measurement ['step',
339
+ 'seconds']
340
+ --checkpoint.interval CHECKPOINT.INTERVAL
341
+ Checkpointing interval, in steps or seconds depending
342
+ on --checkpoint.interval_type
343
+ --checkpoint.model_weights_only
344
+ When model_weights_only=True, only model weights will
345
+ be saved at the end of training. With this,
346
+ checkpoints can be loaded using `torch.load(...,
347
+ weights_only=True)` after conversion. When
348
+ model_weights_only=False, the full checkpoint will be
349
+ saved. A full checkpoint includes model, optimizer and
350
+ train_state, which can be used to resume training. The
351
+ default value is false.
352
+ --checkpoint.export_dtype {float16,bfloat16,float32}
353
+ Converts to the specified precision when training
354
+ completes and model_weights_only=true. Currently
355
+ supports float32, float16, and bfloat16. The default
356
+ value is float32.
357
+ --checkpoint.create_seed_checkpoint
358
+ Initializes the full model without applying
359
+ parallelisms, and then saves it as a seed checkpoint.
360
+ Note: requires user to call train.py without
361
+ specifying any parallelisms, e.g. NGPU=1. Could be
362
+ implemented as a separate script, but this way shares
363
+ more code.
364
+ --checkpoint.async_mode CHECKPOINT.ASYNC_MODE
365
+ Which async checkpoint mode to use. Currently there
366
+ are 3 different modes. 1. "disabled": synchronized
367
+ checkpointing will be used. 2. "async":
368
+ torch.distributed.checkpoint.async_save will be used.
369
+ 1. "async_with_pinned_mem": this option utilizes a
370
+ dedicated pinned memory space and creates a separate
371
+ process for faster GPU->CPU transfer performance and
372
+ eliminating GIL contention. The cost is increased CPU
373
+ memory usage. If insufficient CPU memory is available,
374
+ performance may degrade due to memory paging. For most
375
+ users, "async" should suffice as the performance
376
+ overhead is typically small (on the order of tens of
377
+ seconds) compared to checkpointing frequency. This
378
+ mode can be employed to pursue near-zero checkpointing
379
+ times (e.g., < 1 second) given appropriate hardware
380
+ support such as ample CPU memory and fast PCIe.
381
+ "disabled" is the default mode.
382
+ --checkpoint.keep_latest_k CHECKPOINT.KEEP_LATEST_K
383
+ Keeps only the latest k checkpoints, and purging older
384
+ ones. If 0, keep all checkpoints. 0 is the default
385
+ value.
386
+ --checkpoint.load_step CHECKPOINT.LOAD_STEP
387
+ Load the checkpoint at the specified step. If -1, load
388
+ the latest checkpoint.
389
+ --float8.enable_float8_linear
390
+ If true, swaps `torch.nn.Linear` with `Float8Linear`.
391
+ This feature requires you to install 'torchao' which
392
+ can be found here: https://github.com/pytorch/ao
393
+ --float8.enable_fsdp_float8_all_gather
394
+ Whether enable float8 all-gather in FSDP
395
+ --float8.precompute_float8_dynamic_scale_for_fsdp
396
+ Whether precompute float8 scales dynamically for FSDP
397
+ --float8.scaling_type_input {dynamic,delayed}
398
+ float8 scaling for input, dynamic (default) or delayed
399
+ --float8.scaling_type_weight FLOAT8.SCALING_TYPE_WEIGHT
400
+ float8 scaling for input, dynamic (default) or delayed
401
+ --float8.scaling_type_grad_output FLOAT8.SCALING_TYPE_GRAD_OUTPUT
402
+ float8 scaling for input, dynamic (default) or delayed
403
+ --comm.init_timeout_seconds COMM.INIT_TIMEOUT_SECONDS
404
+ Timeout for communication operations, during
405
+ initialization and first train step.
406
+ --comm.train_timeout_seconds COMM.TRAIN_TIMEOUT_SECONDS
407
+ Timeout for communication operations after the first
408
+ train step -- usually a tighter bound than during
409
+ initialization.
410
+ --comm.trace_buf_size COMM.TRACE_BUF_SIZE
411
+ Flight recorder ring buffer size, >0 means recording
412
+ by default, 0 means disabled
413
+ --memory_estimation.enabled
414
+ Whether to estimate memory usage for FSDP
415
+ --memory_estimation.disable_fake_mode
416
+ Whether to estimate memory under FakeTensorMode
417
+ ```
418
+ </details>
419
+
420
+ ### Training with variable-length inputs
421
+ When you set the `--training.varlen` flag, you're enabling a more efficient training method that packs multiple documents together into a single long sequence, eliminating the need for padding.
422
+ This is particularly useful when your dataset contains documents of varying lengths.
423
+ Let's break down how `--training.seq_len` and `--training.context_len` work in this mode.
424
+
425
+ * `--training.seq_len` (Packed Sequence Length): This is the total length of the final sequence fed to the model on one device. Instead of processing one document at a time, the dataloader takes multiple documents (each split to sequences no longer than `context_len`), concatenates them end-to-end, and creates a single long sequence of length `seq_len`.
426
+ * `--training.context_len` (Sample Length): This parameter defines the maximum number of tokens for a single document or sample. If a document from the dataset is longer than `context_len`, it will be truncated. For example, if `--training.context_len` is set to 4,096, a document with 5,000 tokens will be cut down to its first 4,096 tokens, leaving the left tokens as another independent sequence, while a document with 3000 tokens remains unchanged.
427
+
428
+ ### Training with `torch.compile`
429
+
430
+ Starting from `torch 2.0`, `torch.compile` has been introduced as a new feature to seamlessly accelerate training processes.
431
+ In `flame`, one can simply enable `torch.compile` by adding `--training.compile` flag to your training script.
432
+
433
+ However, `fla` has integrated numerous fused kernels for acceleration, which may potentially conflict with `torch.compile`.
434
+ We are actively working on resolving these issues to make compilation transparent to users.
435
+ In the meantime, please ensure you are using the latest dependencies.
436
+
437
+ Specifically, **we recommend using `torch>=2.6` and `triton>=3.0`**.
438
+
439
+ ### Training with multiple datasets
440
+
441
+ If you wish to train a model with all-round capabilities (e.g., code, math, and multilingual ability), it's necessary to train on multiple datasets.
442
+ `flame` allows training with multiple datasets easily.
443
+ For example, you can specify the following arguments to train on 6 datasets with different proportions:
444
+
445
+ ```sh
446
+ --training.dataset HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus \
447
+ --training.data_probs 0.6,0.15,0.15,0.014,0.058,0.028 \
448
+ ```
449
+
450
+ ### ~Finalizing training~
451
+
452
+ > [!NOTE]
453
+ > We have done this conversion automatically in the training script since our latest updates.
454
+
455
+ Once training is complete, you may want to convert the distributed checkpoints (DCPs) into the 🤗 format for broader use.
456
+ To facilitate this, we provide a straightforward conversion script:
457
+
458
+ ```sh
459
+ python -m flame.utils.convert_dcp_to_hf --path <path_to_model> --step <step> --config <path_to_config> --tokenizer <path_to_tokenizer>
460
+ ```
461
+ After this, your model will be in the 🤗 format, ready to be shared or deployed.
462
+ You can then easily publish your model using the `huggingface_hub` for wider accessibility.
463
+
464
+ ### Continual training
465
+
466
+ If you wish to build upon a strong pre-trained model (in 🤗 format) and continue training, we also offer a script to convert the 🤗 format model back into DCP format.
467
+ This allows you to seamlessly resume training with `flame`.
468
+ ```sh
469
+ python -m flame.utils.convert_hf_to_dcp --model <path_to_hf> --checkpoint <path_to_dcp/checkpoint/step-0>
470
+ ```
471
+ Here, `<path_to_dcp>` is the directory where your distributed checkpoints will be stored.
472
+ The checkpoint is intentionally saved at `<step-0>` within the checkpoint folder to ensure it is loadable by `flame` during the initial training step, similar to how a seed checkpoint is handled.
473
+
474
+ Once the conversion is complete, you can proceed with training using `flame` as usual, continuing from where the pretrained model left off.
475
+
476
+ ## Multi-node training
477
+
478
+ If you have access to multi-node GPUs, consider leveraging them for optimal performance.
479
+ This process is straightforward and well-documented in the PyTorch [docs](https://pytorch.org/docs/stable/elastic/run.html).
480
+
481
+ To set up multi-node training:
482
+ * Set the environment variables `MASTER_ADDR=<ip>` and `MASTER_PORT=<port>` before running the training script across all nodes.
483
+ * If you're using a job scheduler like Slurm, it will handle these variables for you.
484
+
485
+ `torchtitan` provides a [Slurm script](https://github.com/pytorch/torchtitan/blob/main/multinode_trainer.slurm) for multi-node training, which you can use as a reference or starting point.
486
+
487
+ ## Custom models
488
+
489
+ `flame` supports custom model architectures through seamless integration with the Hugging Face `transformers` library. To add your own model:
490
+
491
+ 1. Create a new model directory under `custom_models/` (see `custom_models/sba` for a complete example)
492
+ 2. Implement your model classes and configuration:
493
+ - Define a config class inheriting from `PretrainedConfig` (see `custom_models/sba/config_sba.py` for an example)
494
+ - Create model classes inheriting from `PreTrainedModel` (see `custom_models/sba/modeling_sba.py` for an example)
495
+ 3. Register your models in `__init__.py`:
496
+ - Import your model classes and config classes
497
+ - Register your models with the `AutoModelForCausalLM`, `AutoModel` and `AutoConfig` classes (see `custom_models/sba/__init__.py` for an example)
498
+ 4. Create a config file for your custom model, just need to specify the `model_type` to the one you just named for your custom model (example: `configs/sba_340m.json`).
499
+ 5. Training is extremely simple, you can just use the `flame.train.py` script to train your custom model.
500
+
501
+
502
+
503
+
504
+
505
+
506
+
507
+ ## Citation
508
+
509
+ If you find `flame` helpful for your work, please consider citing it.
510
+
511
+ ```bib
512
+ @software{yang2025flame,
513
+ title = {Flame: Flash Language Modeling Made Easy},
514
+ author = {Zhang, Yu and Yang, Songlin},
515
+ url = {https://github.com/fla-org/flame},
516
+ month = jan,
517
+ year = {2025}
518
+ }
519
+ ```
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/.metadata ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f68afac89910e26060ee26576f22d85b70890770c105bbc9d3cf15f6c3912e0
3
+ size 2163627
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__0_0.distcp ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f368aa85e82d8df07b9207a69db82d4d01ce3f4832e0b92dc51eb9e55609c5c
3
+ size 2264220780
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__1_0.distcp ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:259c9cbad5e6e1c7331daea4e737eba3bc8f7bbd34a8dd63d258a60a0d44e46d
3
+ size 2039609256
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__2_0.distcp ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e89f44b912e3fbf16b44351acbaae32f1ad152359e2bbad03ceba1c24696b192
3
+ size 2029443450
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__3_0.distcp ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e073ec4d2a35983b6a54f86b71edbdf9d56de48c0643a86fdffcd4eb7a4d7432
3
+ size 2267232916
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__4_0.distcp ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ee9cef3c9295916bcee37b80857a1817c4ba99fefb65d4b30ad46675e1996a5
3
+ size 2040124587
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__5_0.distcp ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5791a15bad03274be4358d10bd72a5998f43a12c5d27614d018f42baec1b226
3
+ size 2038161975
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__6_0.distcp ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eaf7bae1b5a5d5a487e003d8e79ad3f1b8faefbccc7881560590217802fb1e5d
3
+ size 2028462328
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/checkpoint/step-30720/__7_0.distcp ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39a90c47888aae8230b79460cd72bfbc0ba8f737f6267b5d8d6d6396d7d99b3e
3
+ size 2038766697
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "allow_neg_eigval": false,
3
+ "architectures": [
4
+ "GatedDeltaNetForCausalLM"
5
+ ],
6
+ "attn": null,
7
+ "attn_mode": "chunk",
8
+ "bos_token_id": 1,
9
+ "conv_size": 4,
10
+ "dtype": "float32",
11
+ "eos_token_id": 2,
12
+ "expand_v": 2,
13
+ "fuse_cross_entropy": true,
14
+ "fuse_linear_cross_entropy": false,
15
+ "fuse_norm": true,
16
+ "fuse_swiglu": true,
17
+ "head_dim": 256,
18
+ "hidden_act": "swish",
19
+ "hidden_ratio": 4,
20
+ "hidden_size": 2048,
21
+ "initializer_range": 0.02,
22
+ "intermediate_size": null,
23
+ "max_position_embeddings": 2048,
24
+ "model_type": "gated_deltanet",
25
+ "norm_eps": 1e-06,
26
+ "num_heads": 6,
27
+ "num_hidden_layers": 21,
28
+ "num_v_heads": null,
29
+ "tie_word_embeddings": false,
30
+ "transformers_version": "4.57.6",
31
+ "use_cache": true,
32
+ "use_gate": true,
33
+ "use_l2warp": false,
34
+ "use_short_conv": true,
35
+ "vocab_size": 32000
36
+ }
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/blt_transformer_1000hash.json ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "blt",
3
+ "vocab_size": 260,
4
+ "max_position_embeddings": 4096,
5
+ "initializer_range": 0.02,
6
+ "tie_word_embeddings": false,
7
+ "patch_in_forward": true,
8
+ "patch_size": 4,
9
+ "patching_mode": "entropy",
10
+ "patching_threshold": 1.335442066192627,
11
+ "patching_batch_size": 1,
12
+ "max_patch_length": null,
13
+ "patching_device": "cuda",
14
+ "realtime_patching": true,
15
+ "patching_threshold_add": null,
16
+ "monotonicity": false,
17
+ "cross_attn_k": 2,
18
+ "encoder_hash_byte_group_size": [3, 4, 5, 6, 7, 8],
19
+ "encoder_hash_byte_group_vocab": 1000,
20
+ "encoder_hash_byte_group_nb_functions": 1,
21
+ "patcher_config": {
22
+ "model_type": "blt_patcher",
23
+ "vocab_size": 260,
24
+ "hidden_size": 512,
25
+ "num_hidden_layers": 7,
26
+ "num_attention_heads": 8,
27
+ "num_key_value_heads": 8,
28
+ "max_position_embeddings": 8192,
29
+ "rms_norm_eps": 1e-5,
30
+ "dropout": 0.0,
31
+ "intermediate_size": 1365,
32
+ "hidden_act": "silu",
33
+ "initializer_range": 0.02,
34
+ "rope_parameters": {"rope_type": "default",
35
+ "rope_theta": 500000
36
+ }
37
+ },
38
+ "encoder_config": {
39
+ "model_type": "blt_local_encoder",
40
+ "vocab_size": 260,
41
+ "hidden_size": 512,
42
+ "hidden_size_global": 1024,
43
+ "num_hidden_layers": 1,
44
+ "num_attention_heads": 8,
45
+ "num_key_value_heads": 8,
46
+ "head_dim": 64,
47
+ "intermediate_size": 1365,
48
+ "rms_norm_eps": 1e-5,
49
+ "dropout": 0.0,
50
+ "max_position_embeddings": 24576,
51
+ "cross_attn_all_layers": false,
52
+ "cross_attn_k": 2,
53
+ "hidden_act": "silu",
54
+ "initializer_range": 0.02,
55
+ "rope_parameters": {"rope_type": "default",
56
+ "rope_theta": 500000
57
+ }
58
+ },
59
+ "decoder_config": {
60
+ "model_type": "blt_local_decoder",
61
+ "vocab_size": 260,
62
+ "hidden_size": 512,
63
+ "hidden_size_global": 1024,
64
+ "num_hidden_layers": 9,
65
+ "num_attention_heads": 8,
66
+ "num_key_value_heads": 8,
67
+ "head_dim": 64,
68
+ "intermediate_size": 1365,
69
+ "rms_norm_eps": 1e-5,
70
+ "dropout": 0.0,
71
+ "max_position_embeddings": 24576,
72
+ "cross_attn_all_layers": true,
73
+ "cross_attn_k": 2,
74
+ "hidden_act": "silu",
75
+ "initializer_range": 0.02,
76
+ "rope_parameters": {"rope_type": "default",
77
+ "rope_theta": 500000
78
+ }
79
+ },
80
+ "global_config": {
81
+ "model_type": "blt_global_transformer",
82
+ "hidden_size": 1024,
83
+ "num_hidden_layers": 25,
84
+ "num_attention_heads": 8,
85
+ "num_key_value_heads": 8,
86
+ "head_dim": 128,
87
+ "intermediate_size": 2731,
88
+ "rms_norm_eps": 1e-5,
89
+ "dropout": 0.0,
90
+ "max_position_embeddings": 4096,
91
+ "hidden_act": "silu",
92
+ "initializer_range": 0.02,
93
+ "rope_parameters": {"rope_type": "default",
94
+ "rope_theta": 500000
95
+ },
96
+ "encoder_cross_output_size": null
97
+ }
98
+ }
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/blt_transformer_1_5B.json ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "blt",
3
+ "vocab_size": 260,
4
+ "max_position_embeddings": 4096,
5
+ "initializer_range": 0.02,
6
+ "tie_word_embeddings": false,
7
+ "patch_in_forward": true,
8
+ "patch_size": 4,
9
+ "patching_mode": "entropy",
10
+ "patching_threshold": 1.335442066192627,
11
+ "patching_batch_size": 1,
12
+ "max_patch_length": null,
13
+ "patching_device": "cuda",
14
+ "realtime_patching": true,
15
+ "patching_threshold_add": null,
16
+ "monotonicity": false,
17
+ "cross_attn_k": 2,
18
+ "encoder_hash_byte_group_size": [3, 4, 5, 6, 7, 8],
19
+ "encoder_hash_byte_group_vocab": 500,
20
+ "encoder_hash_byte_group_nb_functions": 1,
21
+ "patcher_config": {
22
+ "model_type": "blt_patcher",
23
+ "vocab_size": 260,
24
+ "hidden_size": 768,
25
+ "num_hidden_layers": 7,
26
+ "num_attention_heads": 12,
27
+ "num_key_value_heads": 12,
28
+ "max_position_embeddings": 8192,
29
+ "rms_norm_eps": 1e-5,
30
+ "dropout": 0.0,
31
+ "intermediate_size": 2048,
32
+ "hidden_act": "silu",
33
+ "initializer_range": 0.02,
34
+ "rope_parameters": {"rope_type": "default",
35
+ "rope_theta": 500000
36
+ }
37
+ },
38
+ "encoder_config": {
39
+ "model_type": "blt_local_encoder",
40
+ "vocab_size": 260,
41
+ "hidden_size": 1024,
42
+ "hidden_size_global": 2048,
43
+ "num_hidden_layers": 1,
44
+ "num_attention_heads": 16,
45
+ "num_key_value_heads": 16,
46
+ "head_dim": 64,
47
+ "intermediate_size": 2816,
48
+ "rms_norm_eps": 1e-5,
49
+ "dropout": 0.0,
50
+ "max_position_embeddings": 24576,
51
+ "cross_attn_all_layers": false,
52
+ "cross_attn_k": 2,
53
+ "hidden_act": "silu",
54
+ "initializer_range": 0.02,
55
+ "rope_parameters": {"rope_type": "default",
56
+ "rope_theta": 500000
57
+ }
58
+ },
59
+
60
+ "decoder_config": {
61
+ "model_type": "blt_local_decoder",
62
+ "vocab_size": 260,
63
+ "hidden_size": 1024,
64
+ "hidden_size_global": 2048,
65
+ "num_hidden_layers": 9,
66
+ "num_attention_heads": 16,
67
+ "num_key_value_heads": 16,
68
+ "head_dim": 64,
69
+ "intermediate_size": 2816,
70
+ "rms_norm_eps": 1e-5,
71
+ "dropout": 0.0,
72
+ "max_position_embeddings": 24576,
73
+ "cross_attn_all_layers": true,
74
+ "cross_attn_k": 2,
75
+ "hidden_act": "silu",
76
+ "initializer_range": 0.02,
77
+ "rope_parameters": {"rope_type": "default",
78
+ "rope_theta": 500000
79
+ }
80
+ },
81
+ "global_config": {
82
+ "model_type": "blt_global_transformer",
83
+ "hidden_size": 2048,
84
+ "num_hidden_layers": 25,
85
+ "num_attention_heads": 16,
86
+ "num_key_value_heads": 16,
87
+ "head_dim": 128,
88
+ "intermediate_size": 5632,
89
+ "rms_norm_eps": 1e-5,
90
+ "dropout": 0.0,
91
+ "max_position_embeddings": 4096,
92
+ "hidden_act": "silu",
93
+ "initializer_range": 0.02,
94
+ "rope_parameters": {"rope_type": "default",
95
+ "rope_theta": 500000
96
+ },
97
+ "encoder_cross_output_size": null
98
+ }
99
+ }
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/blt_transformer_380M.json ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "blt",
3
+ "vocab_size": 260,
4
+ "max_position_embeddings": 4096,
5
+ "initializer_range": 0.02,
6
+ "tie_word_embeddings": false,
7
+ "patch_in_forward": true,
8
+ "patch_size": 4,
9
+ "patching_mode": "entropy",
10
+ "patching_threshold": 1.335442066192627,
11
+ "patching_batch_size": 1,
12
+ "max_patch_length": null,
13
+ "patching_device": "cuda",
14
+ "realtime_patching": true,
15
+ "patching_threshold_add": null,
16
+ "monotonicity": false,
17
+ "cross_attn_k": 2,
18
+ "encoder_hash_byte_group_size": [3, 4, 5, 6, 7, 8],
19
+ "encoder_hash_byte_group_vocab": 500,
20
+ "encoder_hash_byte_group_nb_functions": 1,
21
+ "patcher_config": {
22
+ "model_type": "blt_patcher",
23
+ "vocab_size": 260,
24
+ "hidden_size": 512,
25
+ "num_hidden_layers": 7,
26
+ "num_attention_heads": 8,
27
+ "num_key_value_heads": 8,
28
+ "max_position_embeddings": 8192,
29
+ "rms_norm_eps": 1e-5,
30
+ "dropout": 0.0,
31
+ "intermediate_size": 1365,
32
+ "hidden_act": "silu",
33
+ "initializer_range": 0.02,
34
+ "rope_parameters": {"rope_type": "default",
35
+ "rope_theta": 500000
36
+ }
37
+ },
38
+ "encoder_config": {
39
+ "model_type": "blt_local_encoder",
40
+ "vocab_size": 260,
41
+ "hidden_size": 512,
42
+ "hidden_size_global": 1024,
43
+ "num_hidden_layers": 1,
44
+ "num_attention_heads": 8,
45
+ "num_key_value_heads": 8,
46
+ "head_dim": 64,
47
+ "intermediate_size": 1365,
48
+ "rms_norm_eps": 1e-5,
49
+ "dropout": 0.0,
50
+ "max_position_embeddings": 24576,
51
+ "cross_attn_all_layers": false,
52
+ "cross_attn_k": 2,
53
+ "hidden_act": "silu",
54
+ "initializer_range": 0.02,
55
+ "rope_parameters": {"rope_type": "default",
56
+ "rope_theta": 500000
57
+ }
58
+ },
59
+ "decoder_config": {
60
+ "model_type": "blt_local_decoder",
61
+ "vocab_size": 260,
62
+ "hidden_size": 512,
63
+ "hidden_size_global": 1024,
64
+ "num_hidden_layers": 9,
65
+ "num_attention_heads": 8,
66
+ "num_key_value_heads": 8,
67
+ "head_dim": 64,
68
+ "intermediate_size": 1365,
69
+ "rms_norm_eps": 1e-5,
70
+ "dropout": 0.0,
71
+ "max_position_embeddings": 24576,
72
+ "cross_attn_all_layers": true,
73
+ "cross_attn_k": 2,
74
+ "hidden_act": "silu",
75
+ "initializer_range": 0.02,
76
+ "rope_parameters": {"rope_type": "default",
77
+ "rope_theta": 500000
78
+ }
79
+ },
80
+ "global_config": {
81
+ "model_type": "blt_global_transformer",
82
+ "hidden_size": 1024,
83
+ "num_hidden_layers": 25,
84
+ "num_attention_heads": 8,
85
+ "num_key_value_heads": 8,
86
+ "head_dim": 128,
87
+ "intermediate_size": 2731,
88
+ "rms_norm_eps": 1e-5,
89
+ "dropout": 0.0,
90
+ "max_position_embeddings": 4096,
91
+ "hidden_act": "silu",
92
+ "initializer_range": 0.02,
93
+ "rope_parameters": {"rope_type": "default",
94
+ "rope_theta": 500000
95
+ },
96
+ "encoder_cross_output_size": null
97
+ }
98
+ }
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/delta_net_1B.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attn": null,
3
+ "attn_mode": "chunk",
4
+ "bos_token_id": 1,
5
+ "conv_size": 4,
6
+ "eos_token_id": 2,
7
+ "expand_k": 1,
8
+ "expand_v": 1,
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "hidden_act": "swish",
12
+ "hidden_ratio": 4,
13
+ "hidden_size": 2048,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": null,
16
+ "model_type": "delta_net",
17
+ "norm_eps": 1e-06,
18
+ "num_heads": 16,
19
+ "num_hidden_layers": 24,
20
+ "pad_token_id": 2,
21
+ "qk_activation": "silu",
22
+ "qk_norm": "l2",
23
+ "tie_word_embeddings": false,
24
+ "use_beta": true,
25
+ "use_cache": true,
26
+ "use_gate": false,
27
+ "use_output_norm": true,
28
+ "use_short_conv": true
29
+ }
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/delta_net_340M.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "conv_size": 4,
5
+ "eos_token_id": 2,
6
+ "expand_k": 1,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "hidden_act": "swish",
10
+ "hidden_ratio": 4,
11
+ "hidden_size": 1024,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": null,
14
+ "model_type": "delta_net",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 8,
17
+ "num_hidden_layers": 24,
18
+ "qk_activation": "silu",
19
+ "qk_norm": "l2",
20
+ "tie_word_embeddings": false,
21
+ "use_beta": true,
22
+ "use_cache": true,
23
+ "use_gate": false,
24
+ "use_output_norm": true,
25
+ "use_short_conv": true
26
+ }
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/gated_deltanet_1B.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "conv_size": 4,
5
+ "eos_token_id": 2,
6
+ "expand_v": 2,
7
+ "fuse_cross_entropy": true,
8
+ "head_dim": 256,
9
+ "hidden_act": "swish",
10
+ "hidden_ratio": 4,
11
+ "hidden_size": 2048,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": null,
14
+ "model_type": "gated_deltanet",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 6,
17
+ "num_hidden_layers": 21,
18
+ "tie_word_embeddings": false,
19
+ "use_cache": true,
20
+ "use_gate": true,
21
+ "use_short_conv": true
22
+ }
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/gated_deltanet_340M.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "conv_size": 4,
5
+ "eos_token_id": 2,
6
+ "expand_v": 2,
7
+ "fuse_cross_entropy": true,
8
+ "head_dim": 256,
9
+ "hidden_act": "swish",
10
+ "hidden_ratio": 4,
11
+ "hidden_size": 1024,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": null,
14
+ "model_type": "gated_deltanet",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 6,
17
+ "num_hidden_layers": 21,
18
+ "tie_word_embeddings": false,
19
+ "use_cache": true,
20
+ "use_gate": true,
21
+ "use_short_conv": true
22
+ }
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/gated_deltanet_h_340M.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "gated_deltanet",
3
+ "attn_mode": "chunk",
4
+ "hidden_size": 1024,
5
+ "num_hidden_layers": 21,
6
+ "head_dim": 256,
7
+ "num_heads": 6,
8
+ "expand_v": 2,
9
+ "hidden_ratio": 4,
10
+ "use_gate": true,
11
+ "use_short_conv": true,
12
+ "conv_size": 4,
13
+ "vocab_size": 32000,
14
+ "hidden_act": "swish",
15
+ "norm_eps": 1e-06,
16
+ "bos_token_id": 1,
17
+ "eos_token_id": 2,
18
+ "fuse_cross_entropy": true,
19
+ "initializer_range": 0.02,
20
+ "attn": {
21
+ "layers": [3, 7, 11, 15, 19],
22
+ "num_heads": 8,
23
+ "num_kv_heads": 1,
24
+ "window_size": 2048,
25
+ "rope_theta": 100000.0,
26
+ "qkv_bias": false
27
+ }
28
+ }
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/gla_340M.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "clamp_min": null,
5
+ "eos_token_id": 2,
6
+ "expand_k": 0.5,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": null,
15
+ "model_type": "gla",
16
+ "num_heads": 4,
17
+ "num_hidden_layers": 24,
18
+ "norm_eps": 1e-06,
19
+ "tie_word_embeddings": false,
20
+ "use_cache": true,
21
+ "use_gk": true,
22
+ "use_gv": false,
23
+ "vocab_size": 32000
24
+ }
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/gla_7B.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attn": null,
3
+ "attn_mode": "chunk",
4
+ "bos_token_id": 1,
5
+ "eos_token_id": 2,
6
+ "expand_k": 0.5,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 4096,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 11008,
15
+ "model_type": "gla",
16
+ "norm_eps": 1e-06,
17
+ "num_heads": 16,
18
+ "num_hidden_layers": 32,
19
+ "tie_word_embeddings": false,
20
+ "use_cache": true,
21
+ "use_gk": true,
22
+ "use_gv": false,
23
+ "use_output_gate": true,
24
+ "use_short_conv": false
25
+ }
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/gsa_340M.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "conv_size": 4,
4
+ "eos_token_id": 2,
5
+ "expand_k": 1,
6
+ "expand_v": 1,
7
+ "elementwise_affine": false,
8
+ "feature_map": "swish",
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "gate_logit_normalizer": 4,
12
+ "hidden_act": "swish",
13
+ "hidden_ratio": 4,
14
+ "hidden_size": 1024,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": null,
17
+ "model_type": "gsa",
18
+ "num_heads": 4,
19
+ "num_hidden_layers": 24,
20
+ "num_slots": 64,
21
+ "norm_eps": 1e-06,
22
+ "share_conv_kernel": true,
23
+ "tie_word_embeddings": false,
24
+ "use_cache": true,
25
+ "use_norm": true,
26
+ "use_output_gate": true,
27
+ "use_rope": false,
28
+ "use_short_conv": false
29
+ }
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/mergenet_340M.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "mergenet",
3
+ "vocab_size": 260,
4
+ "hidden_size": 1024,
5
+ "num_local_layers": 6,
6
+ "local_depth": 4,
7
+ "num_latent_layers": 12,
8
+ "num_heads": 16,
9
+ "num_kv_heads": 16,
10
+ "intermediate_size": 4096,
11
+ "hidden_act": "swish",
12
+ "max_position_embeddings": 8192,
13
+ "lambda_local": 4.0,
14
+ "dtem_window_size": 8,
15
+ "dtem_t": 1,
16
+ "dtem_feat_dim": null,
17
+ "use_softkmax": false,
18
+ "grid_bias_gamma": 1.0,
19
+ "W_infer": null,
20
+ "qkv_bias": true,
21
+ "qk_norm": false,
22
+ "rope_theta": 10000.0,
23
+ "norm_eps": 1e-6,
24
+ "initializer_range": 0.02,
25
+ "use_cache": true,
26
+ "pad_token_id": 0,
27
+ "bos_token_id": 1,
28
+ "eos_token_id": 2,
29
+ "tie_word_embeddings": false,
30
+ "phase": "phase2",
31
+ "drop_rate": 0.0,
32
+ "attn_drop_rate": 0.0,
33
+ "drop_path_rate": 0.1
34
+ }
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/mergenet_64M.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "mergenet",
3
+ "vocab_size": 32000,
4
+ "hidden_size": 512,
5
+ "num_local_layers": 4,
6
+ "local_depth": 4,
7
+ "num_latent_layers": 8,
8
+ "num_heads": 8,
9
+ "num_kv_heads": 8,
10
+ "intermediate_size": 2048,
11
+ "hidden_act": "swish",
12
+ "max_position_embeddings": 4096,
13
+ "lambda_local": 4.0,
14
+ "dtem_window_size": 8,
15
+ "dtem_t": 1,
16
+ "dtem_feat_dim": null,
17
+ "use_softkmax": false,
18
+ "grid_bias_gamma": 1.0,
19
+ "W_infer": null,
20
+ "qkv_bias": true,
21
+ "qk_norm": false,
22
+ "rope_theta": 10000.0,
23
+ "norm_eps": 1e-6,
24
+ "initializer_range": 0.02,
25
+ "use_cache": true,
26
+ "pad_token_id": 0,
27
+ "bos_token_id": 1,
28
+ "eos_token_id": 2,
29
+ "tie_word_embeddings": false,
30
+ "phase": "phase2",
31
+ "drop_rate": 0.0,
32
+ "attn_drop_rate": 0.0,
33
+ "drop_path_rate": 0.1
34
+ }
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/qwen3_next_1B.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "qwen3_next",
3
+ "vocab_size": 151936,
4
+ "hidden_size": 2048,
5
+ "intermediate_size": 5632,
6
+ "num_hidden_layers": 48,
7
+ "num_attention_heads": 16,
8
+ "num_key_value_heads": 2,
9
+ "head_dim": 256,
10
+ "hidden_act": "silu",
11
+ "max_position_embeddings": 32768,
12
+ "initializer_range": 0.02,
13
+ "rms_norm_eps": 1e-6,
14
+ "use_cache": true,
15
+ "tie_word_embeddings": false,
16
+ "attention_bias": false,
17
+ "attention_dropout": 0.0,
18
+ "rope_parameters": {
19
+ "rope_type": "default",
20
+ "factor": 1.0
21
+ },
22
+ "partial_rotary_factor": 0.25,
23
+ "layer_types": [
24
+ "linear_attention",
25
+ "linear_attention",
26
+ "linear_attention",
27
+ "full_attention"
28
+ ],
29
+ "linear_conv_kernel_dim": 4,
30
+ "linear_key_head_dim": 128,
31
+ "linear_value_head_dim": 128,
32
+ "linear_num_key_heads": 16,
33
+ "linear_num_value_heads": 32,
34
+ "decoder_sparse_step": 1,
35
+ "moe_intermediate_size": 512,
36
+ "shared_expert_intermediate_size": 512,
37
+ "num_experts_per_tok": 10,
38
+ "num_experts": 512,
39
+ "norm_topk_prob": true,
40
+ "output_router_logits": false,
41
+ "router_aux_loss_coef": 0.001,
42
+ "mlp_only_layers": []
43
+ }
44
+
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/qwen3_next_350M.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "qwen3_next",
3
+ "vocab_size": 32000,
4
+ "hidden_size": 2048,
5
+ "intermediate_size": 5632,
6
+ "num_hidden_layers": 26,
7
+ "num_attention_heads": 16,
8
+ "num_key_value_heads": 2,
9
+ "head_dim": 256,
10
+ "hidden_act": "silu",
11
+ "max_position_embeddings": 32768,
12
+ "initializer_range": 0.02,
13
+ "rms_norm_eps": 1e-6,
14
+ "use_cache": true,
15
+ "tie_word_embeddings": false,
16
+ "attention_bias": false,
17
+ "attention_dropout": 0.0,
18
+ "rope_parameters": {
19
+ "rope_type": "default",
20
+ "factor": 1.0
21
+ },
22
+ "partial_rotary_factor": 0.25,
23
+ "layer_types": [
24
+ "linear_attention",
25
+ "linear_attention",
26
+ "linear_attention",
27
+ "full_attention"
28
+ ],
29
+ "linear_conv_kernel_dim": 4,
30
+ "linear_key_head_dim": 128,
31
+ "linear_value_head_dim": 128,
32
+ "linear_num_key_heads": 16,
33
+ "linear_num_value_heads": 32,
34
+ "decoder_sparse_step": 1,
35
+ "moe_intermediate_size": 512,
36
+ "shared_expert_intermediate_size": 512,
37
+ "num_experts_per_tok": 10,
38
+ "num_experts": 512,
39
+ "norm_topk_prob": true,
40
+ "output_router_logits": false,
41
+ "router_aux_loss_coef": 0.001,
42
+ "mlp_only_layers": []
43
+ }
44
+
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/transformer_1B.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "elementwise_affine": true,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "fuse_swiglu": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": null,
13
+ "max_position_embeddings": 8192,
14
+ "model_type": "transformer",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 32,
17
+ "num_hidden_layers": 24,
18
+ "num_kv_heads": null,
19
+ "pad_token_id": 2,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false
22
+ }
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/transformer_340M.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000
18
+ }
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/configs/transformer_7B.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_ratio": 4,
9
+ "hidden_size": 4096,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 14336,
12
+ "model_type": "transformer",
13
+ "norm_eps": 1e-06,
14
+ "num_heads": 32,
15
+ "num_hidden_layers": 32,
16
+ "num_kv_heads": 8,
17
+ "rope_theta": 10000.0,
18
+ "tie_word_embeddings": false,
19
+ "use_cache": true,
20
+ "window_size": null
21
+ }
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = "0.1.0"
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (207 Bytes). View file
 
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (244 Bytes). View file
 
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (238 Bytes). View file
 
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/config_manager.cpython-310.pyc ADDED
Binary file (29.6 kB). View file
 
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/config_manager.cpython-311.pyc ADDED
Binary file (41.5 kB). View file
 
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/data.cpython-310.pyc ADDED
Binary file (21.7 kB). View file
 
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/data.cpython-311.pyc ADDED
Binary file (41.6 kB). View file
 
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/train.cpython-310.pyc ADDED
Binary file (19.6 kB). View file
 
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/train.cpython-311.pyc ADDED
Binary file (41.2 kB). View file
 
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/__pycache__/train.cpython-313.pyc ADDED
Binary file (39.7 kB). View file
 
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/c4_test.py ADDED
@@ -0,0 +1,603 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import os
7
+ import time
8
+ import json
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.distributed as dist
13
+
14
+ from tqdm import tqdm
15
+ from loguru import logger
16
+
17
+ import transformers
18
+
19
+ transformers.logging.set_verbosity_error()
20
+
21
+ import wandb
22
+
23
+ from utils.argparse import parse_args
24
+ from utils.setup import getting_svd_cnt, set_seed, setup_model, saving_model_weight, load_model_weight
25
+ from utils.optimizer_factory import setup_optimization
26
+ from utils.eval import evaluate_model
27
+ from utils.dataloader import setup_dataset
28
+ from utils.modeling_llama import LlamaForCausalLM
29
+ from utils.fake_quantization import QLinear
30
+ from utils.quantization import QScaleLinear
31
+
32
+
33
+ def main(args):
34
+ import torch
35
+ ############ Setup random seed ############
36
+ set_seed(args)
37
+
38
+ ############ Setup DDP environment ############
39
+ assert "LOCAL_RANK" in os.environ, "torchrun should set LOCAL_RANK"
40
+ global_rank = int(os.environ["RANK"])
41
+ local_rank = int(os.environ["LOCAL_RANK"])
42
+ world_size = int(os.environ["WORLD_SIZE"])
43
+ torch.cuda.set_device(local_rank)
44
+
45
+ logger.info(f"Global rank {global_rank}, local rank {local_rank}, device: {torch.cuda.current_device()}")
46
+ dist.init_process_group(backend="nccl", rank=global_rank, world_size=world_size)
47
+
48
+ logger.info("Process group initialized")
49
+ device = f"cuda:{local_rank}"
50
+
51
+ if global_rank != 0:
52
+ logger.remove() # turn off logger
53
+
54
+ logger.info(f"Using dist with rank {global_rank} (only rank 0 will log)")
55
+ logger.info("*" * 40)
56
+ logger.info(f"Starting training with the arguments")
57
+ for k, v in vars(args).items():
58
+ logger.info(f"{k:30} {v}")
59
+ logger.info("*" * 40)
60
+
61
+ ############ Initialize wandb without config (it is passed later) ############
62
+ if (not args.unset_wandb) and global_rank == 0:
63
+ if args.entity is None:
64
+ os.environ['WANDB_MODE'] = 'offline'
65
+ # Set wandb directory for offline mode
66
+ wandb_dir = getattr(args, 'wandb_dir', None) if getattr(args, 'wandb_dir', None) is not None else args.save_dir
67
+ if getattr(args, 'wandb_dir', None) is not None:
68
+ logger.info(f"Wandb directory set to: {wandb_dir}")
69
+ wandb.init(project=args.project, name=args.name, entity=args.entity, dir=wandb_dir)
70
+
71
+ ############ Setup training data ############
72
+ if args.total_batch_size is not None:
73
+ if args.gradient_accumulation is None:
74
+ assert args.total_batch_size % world_size == 0, "total_batch_size must be divisible by world_size"
75
+ args.gradient_accumulation = args.total_batch_size // (args.batch_size * world_size)
76
+ assert args.gradient_accumulation > 0, "gradient_accumulation must be greater than 0"
77
+
78
+ assert (
79
+ args.gradient_accumulation * args.batch_size * world_size == args.total_batch_size
80
+ ), "gradient_accumulation * batch_size * world_size must be equal to total_batch_size"
81
+
82
+ dataloader, tokenizer = setup_dataset(args, global_rank, world_size)
83
+
84
+ ############ Initialize model ############
85
+ model_config, model = setup_model(args)
86
+ # Ensure model has generation_config (fix for transformers version compatibility)
87
+ if model.generation_config is None:
88
+ from transformers import GenerationConfig
89
+ model.generation_config = GenerationConfig()
90
+ model.generation_config.pad_token_id = tokenizer.pad_token_id
91
+
92
+ ############ Resuming from checkpoints ############
93
+ global_step = 0
94
+ update_step = 0
95
+ beginning_step = 0
96
+ tokens_seen = 0
97
+ tokens_seen_before = 0
98
+
99
+ # identifying checkpointing
100
+ if args.continue_from is not None and os.path.exists(args.continue_from):
101
+ # searching the latest checkpoints
102
+ checkpoint_path_list = os.listdir(args.continue_from)
103
+ checkpoint_path_list = [int(x.split("_")[-1]) for x in checkpoint_path_list if x.startswith("model_")]
104
+ if len(checkpoint_path_list) > 0:
105
+ logger.info("Find Checkpoints", checkpoint_path_list)
106
+ beginning_step = max(checkpoint_path_list)
107
+ if args.resume_step is not None:
108
+ beginning_step = args.resume_step
109
+ args.continue_from = os.path.join(args.continue_from, f"model_{beginning_step}")
110
+ logger.info("Continue from", args.continue_from)
111
+ else:
112
+ logger.warning(f"Did not find any checkpoints in {args.continue_from}")
113
+ args.continue_from = None
114
+
115
+ # resuming from checkpointing
116
+ if args.continue_from is not None:
117
+ logger.info("*" * 40)
118
+ logger.info(f"Loading model from {args.continue_from}")
119
+ checkpoint_path = os.path.join(args.continue_from, "pytorch_model.bin")
120
+ if os.path.exists(checkpoint_path):
121
+ load_model_weight(model, checkpoint_path, args)
122
+ logger.info(f"Model successfully loaded (strict=False policy)")
123
+ else:
124
+ # Try safetensors format
125
+ checkpoint_path = os.path.join(args.continue_from, "model.safetensors")
126
+ if os.path.exists(checkpoint_path):
127
+ from safetensors import safe_open
128
+ tensors = {}
129
+ with safe_open(checkpoint_path, framework="pt", device=0) as f:
130
+ for k in f.keys():
131
+ tensors[k] = f.get_tensor(k)
132
+ print(k, tensors[k].shape)
133
+ ret = model.load_state_dict(tensors, strict=False)
134
+ logger.info(f"Model successfully loaded from safetensors (strict=False policy)", ret)
135
+ else:
136
+ logger.warning(f"No model checkpoint found in {args.continue_from}")
137
+
138
+ if os.path.exists(os.path.join(args.continue_from, "training_state.json")):
139
+ logger.info(
140
+ f"Loading training state like global_step, update_step, and tokens_seen from {args.continue_from}"
141
+ )
142
+ with open(os.path.join(args.continue_from, "training_state.json")) as f:
143
+ _old_state = json.load(f)
144
+ global_step = _old_state["global_step"]
145
+ update_step = _old_state["update_step"]
146
+ tokens_seen = _old_state["tokens_seen"]
147
+ tokens_seen_before = _old_state["tokens_seen_before"]
148
+ logger.info(f"global_step : {global_step}")
149
+ logger.info(f"update_step : {update_step}")
150
+ logger.info(f"tokens_seen : {tokens_seen}")
151
+ logger.info(f"tokens_seen_before: {tokens_seen_before}")
152
+ logger.info(f"Will train for {args.num_training_steps - update_step} update steps")
153
+ else:
154
+ logger.warning(f"Did not find training state in {args.continue_from}, global step will start from zero")
155
+ logger.info("*" * 40)
156
+
157
+ ############ Setup model ############
158
+ if args.dtype in ["bf16", "bfloat16"]:
159
+ model = model.to(dtype=torch.bfloat16)
160
+ model = model.to(device=device)
161
+
162
+ for _, module in model.named_modules():
163
+ if isinstance(module, QScaleLinear):
164
+ weight_device = module.weight.device
165
+ module.weight.scales = module.weight.scales.to(device=weight_device)
166
+ module.weight.zeros = module.weight.zeros.to(device=weight_device)
167
+
168
+ n_total_params = sum(p.numel() for p in model.parameters())
169
+ trainable_params = [p for p in model.parameters() if p.requires_grad]
170
+ trainable_params_int8 = [p for p in model.parameters() if hasattr(p, "group_size")]
171
+
172
+ ############ Initialize wandb ############
173
+ run_config = dict(vars(args))
174
+ run_config.update(
175
+ {
176
+ "max_lr": run_config.pop("lr"), # rename lr to max_lr to avoid conflicts with scheduler
177
+ "total_params_M": n_total_params / 1_000_000,
178
+ "dataset": "c4",
179
+ "model": model_config.to_dict(),
180
+ "world_size": world_size,
181
+ "device": str(device),
182
+ }
183
+ )
184
+
185
+ if global_rank == 0:
186
+ if not args.unset_wandb:
187
+ wandb.config.update(run_config, allow_val_change=True)
188
+ wandb.save(os.path.abspath(__file__), policy="now") # save current script
189
+ # fix tqdm visual length to 80 so that the progress bar
190
+ # doesn't jump around when changing from external display to laptop
191
+ pbar = tqdm(total=args.num_training_steps - update_step, desc="Update steps", ncols=80)
192
+
193
+ ############ Initialize optimization ############
194
+ if "galore" in args.optimizer.lower():
195
+ # make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
196
+ lowrank_params = []
197
+ target_modules_list = ["attn", "mlp"]
198
+ for module_name, module in model.named_modules():
199
+ if not (isinstance(module, nn.Linear) or isinstance(module, QScaleLinear) or isinstance(module, QLinear)):
200
+ continue
201
+ if not any(target_key in module_name for target_key in target_modules_list):
202
+ continue
203
+ logger.info(f"Adding {module_name} to GaLore parameters")
204
+ lowrank_params.append(module.weight)
205
+
206
+ id_lowrank_params = [id(p) for p in lowrank_params]
207
+ # make parameters without "rank" to another group
208
+ regular_params = [p for p in model.parameters() if id(p) not in id_lowrank_params]
209
+ # then call low rank optimizer
210
+ param_groups = [
211
+ {"params": regular_params},
212
+ {
213
+ "params": lowrank_params,
214
+ "rank": args.rank,
215
+ "update_proj_gap": args.update_proj_gap,
216
+ "scale": args.galore_scale,
217
+ "proj_type": args.proj_type,
218
+ "quant": args.proj_quant,
219
+ "quant_n_bit": args.proj_bits,
220
+ "quant_group_size": args.proj_group_size,
221
+ "cos_threshold": args.cos_threshold,
222
+ "gamma_proj": args.gamma_proj,
223
+ "queue_size": args.queue_size,
224
+ },
225
+ ]
226
+ elif "apollo" in args.optimizer.lower():
227
+ # make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
228
+ lowrank_params = []
229
+ target_modules_list = ["attn", "mlp"]
230
+ for module_name, module in model.named_modules():
231
+ if not (isinstance(module, nn.Linear) or isinstance(module, QScaleLinear) or isinstance(module, QLinear)):
232
+ continue
233
+ if not any(target_key in module_name for target_key in target_modules_list):
234
+ continue
235
+ logger.info(f"Adding {module_name} to APOLLO parameters")
236
+ lowrank_params.append(module.weight)
237
+
238
+ id_lowrank_params = [id(p) for p in lowrank_params]
239
+ # make parameters without "rank" to another group
240
+ regular_params = [p for p in model.parameters() if id(p) not in id_lowrank_params]
241
+ # then call low rank optimizer
242
+ param_groups = [
243
+ {"params": regular_params},
244
+ {
245
+ "params": lowrank_params,
246
+ "rank": args.rank,
247
+ "update_proj_gap": args.update_proj_gap,
248
+ "scale": args.apollo_scale,
249
+ "proj_type": args.proj_type,
250
+ "proj": args.proj,
251
+ "scale_type": args.scale_type,
252
+ },
253
+ ]
254
+ elif "conda" in args.optimizer.lower():
255
+ # make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
256
+ lowrank_params = []
257
+ target_modules_list = ["attn", "mlp"]
258
+ for module_name, module in model.named_modules():
259
+ if not (isinstance(module, nn.Linear) or isinstance(module, QScaleLinear) or isinstance(module, QLinear)):
260
+ continue
261
+ if not any(target_key in module_name for target_key in target_modules_list):
262
+ continue
263
+ logger.info(f"Adding {module_name} to conda parameters")
264
+ lowrank_params.append(module.weight)
265
+
266
+ id_lowrank_params = [id(p) for p in lowrank_params]
267
+ # make parameters without "rank" to another group
268
+ regular_params = [p for p in model.parameters() if id(p) not in id_lowrank_params]
269
+ # then call low rank optimizer
270
+ param_groups = [
271
+ {"params": regular_params},
272
+ {
273
+ "params": lowrank_params,
274
+ "rank": args.rank,
275
+ "update_proj_gap": args.update_proj_gap,
276
+ "scale": args.apollo_scale,
277
+ "proj_type": args.proj_type,
278
+ "proj": args.proj,
279
+ "scale_type": args.scale_type,
280
+ },
281
+ ]
282
+ else:
283
+ param_groups = None
284
+ id_lowrank_params = None
285
+
286
+ # print params and trainable params
287
+ logger.info(f"\n{model}\n")
288
+ logger.info(f"Total params: {sum(p.numel() for p in model.parameters()) / 1_000_000:.2f}M")
289
+
290
+ if args.simulation:
291
+ num_train_params = sum(p.numel() for p in trainable_params)
292
+ else:
293
+ num_train_params = sum(p.numel() for p in trainable_params) + sum(p.numel() for p in trainable_params_int8)
294
+
295
+ logger.info(f"Trainable params: {num_train_params / 1_000_000:.2f}M")
296
+ if "q_galore" in args.optimizer.lower():
297
+ logger.info(
298
+ f"Trainable params with Q-GaLore enabled: {sum(p.numel() for p in trainable_params_int8) / 1_000_000:.2f}M"
299
+ )
300
+ elif "galore" in args.optimizer.lower():
301
+ logger.info(f"Total params with GaLore enabled: {sum(p.numel() for p in lowrank_params) / 1_000_000:.2f}M")
302
+ elif "q_apollo" in args.optimizer.lower():
303
+ logger.info(
304
+ f"Trainable params with Q-APOLLO enabled: {sum(p.numel() for p in trainable_params_int8) / 1_000_000:.2f}M"
305
+ )
306
+ elif "apollo" in args.optimizer.lower():
307
+ logger.info(f"Total params with APOLLO enabled: {sum(p.numel() for p in lowrank_params) / 1_000_000:.2f}M")
308
+
309
+ logger.info(f"Saving model to {args.save_dir} every {args.save_every} update steps")
310
+
311
+ model, optimizer, scheduler, layer_wise_flag = setup_optimization(
312
+ args, model, trainable_params, param_groups, id_lowrank_params, model_config
313
+ )
314
+
315
+ if layer_wise_flag:
316
+ # will pass optimizer_dict and scheduler_dict out instead of optimizer and scheduler
317
+ optimizer_dict = optimizer
318
+ scheduler_dict = scheduler
319
+
320
+ # Bug-3 fix: wrap with DDP *before* torch.compile per PyTorch recommendation.
321
+ # This ensures gradient reduction hooks are correctly installed on the DDP module,
322
+ # and the compiled graph captures the full DDP+model forward pass.
323
+ # (Issue-5: optimizer.load_state_dict is called after both DDP and compile below.)
324
+ if not args.single_gpu:
325
+ model: LlamaForCausalLM = torch.nn.parallel.DistributedDataParallel(
326
+ model,
327
+ device_ids=[local_rank],
328
+ output_device=local_rank,
329
+ broadcast_buffers=False,
330
+ )
331
+
332
+ # compile the model (after DDP so the compiled graph includes DDP reduction)
333
+ if args.compile:
334
+ print("Compiling the model... (takes a ~minute)")
335
+ unoptimized_model = model
336
+
337
+ # Configure TorchDynamo to suppress errors and fall back to eager mode
338
+ import torch._dynamo
339
+ torch._dynamo.config.suppress_errors = args.dynamo_suppress_errors
340
+ torch._dynamo.config.verbose = False
341
+ # Set cache size limit to prevent memory issues during long training
342
+ torch._dynamo.config.cache_size_limit = args.dynamo_cache_limit
343
+
344
+ model = torch.compile(model) # requires PyTorch 2.0
345
+
346
+ # resume optimizer
347
+ if args.restore_optimizer and args.continue_from is not None:
348
+ logger.info("Restoring optimizer and scheduler from the checkpoint")
349
+ _optimizer_dir = args.continue_from
350
+ optimizer_checkpoint = torch.load(os.path.join(_optimizer_dir, "optimizer.pt"), map_location="cpu")
351
+ optimizer.load_state_dict(optimizer_checkpoint["optimizer"])
352
+ scheduler.load_state_dict(optimizer_checkpoint["scheduler"])
353
+ update_step = optimizer_checkpoint["update_step"]
354
+ beginning_step = update_step
355
+ global_step = optimizer_checkpoint["global_step"]
356
+ logger.info(f"Optimizer and scheduler restored from {_optimizer_dir}")
357
+
358
+ # ##############################
359
+ # TRAINING LOOP
360
+ # we use iterable dataset, so we may never go through all the data
361
+ # ##############################
362
+ # global steps and others are defined above
363
+ pad_idx = tokenizer.pad_token_id
364
+ update_time = time.time()
365
+ local_step = 0 # when continue_from is used, local_step != global_step
366
+ total_svd_count = 0
367
+
368
+ dataloader_iter = iter(dataloader)
369
+
370
+ # Issue-4 fix: accumulate loss across micro-batches so logged loss is the true
371
+ # gradient-accumulation average, not just the last micro-batch.
372
+ accumulated_loss = 0.0
373
+
374
+ # Skip data if resuming from checkpoint
375
+ if update_step != 0:
376
+ skip_batches = args.gradient_accumulation * update_step
377
+ logger.info(f"Skipping {skip_batches} batches to resume from update step {update_step}")
378
+ skipped = 0
379
+ for _ in range(skip_batches):
380
+ # Issue-6 fix: handle StopIteration during skip so all ranks stay aligned
381
+ try:
382
+ next(dataloader_iter)
383
+ except StopIteration:
384
+ logger.warning(
385
+ f"Dataset exhausted during skip at batch {skipped}/{skip_batches}; "
386
+ f"restarting iterator to keep ranks aligned."
387
+ )
388
+ dataloader_iter = iter(dataloader)
389
+ next(dataloader_iter)
390
+ skipped += 1
391
+ logger.info(f"Skipped {skipped} batches successfully")
392
+
393
+ while update_step <= args.num_training_steps:
394
+ try:
395
+ batch = next(dataloader_iter)
396
+ except StopIteration:
397
+ logger.info(f"Dataset completed one epoch. Starting new epoch with reshuffled data.")
398
+ dataloader_iter = iter(dataloader)
399
+ batch = next(dataloader_iter)
400
+
401
+ global_step += 1
402
+ local_step += 1
403
+
404
+ if update_step >= args.num_training_steps:
405
+ logger.info(f"Reached max number of update steps ({args.num_training_steps}). Stopping training.")
406
+ logger.info(f"Rank {global_rank} stopping training.")
407
+ break
408
+
409
+ # forward & backward
410
+ batch = {k: v.to(device) for k, v in batch.items()}
411
+ labels = batch["input_ids"].clone()
412
+ labels[labels == pad_idx] = -100
413
+ tokens_seen += (batch["input_ids"] != pad_idx).sum().item() * world_size
414
+
415
+ loss = model(**batch, labels=labels).loss
416
+
417
+ scaled_loss = loss / args.gradient_accumulation
418
+ scaled_loss.backward()
419
+ accumulated_loss += loss.item() # Issue-4: accumulate before the continue
420
+
421
+ if global_step % args.gradient_accumulation != 0:
422
+ continue
423
+
424
+ # The below code is only executed during the update step
425
+ # Issue-4: compute average loss over all micro-batches in this accumulation window
426
+ avg_loss = accumulated_loss / args.gradient_accumulation
427
+ accumulated_loss = 0.0 # reset for next accumulation window
428
+ # add grad clipping: TODO: add gradient clipping of int8 weight
429
+ if args.grad_clipping != 0.0:
430
+ torch.nn.utils.clip_grad_norm_(trainable_params, args.grad_clipping)
431
+ # Periodic memory cleanup to prevent symbolic tensor issues during long training
432
+ if global_step % args.memory_cleanup_frequency == 0:
433
+ torch.cuda.empty_cache()
434
+ # Clear TorchDynamo cache to prevent memory accumulation
435
+ if args.compile:
436
+ import torch._dynamo
437
+ torch._dynamo.reset()
438
+
439
+ if global_rank == 0:
440
+ pbar.update(1)
441
+ if not layer_wise_flag: # layer-wise updation is done during backward; requires gradient_accumulation equals 1
442
+ optimizer.step()
443
+ scheduler.step()
444
+ optimizer.zero_grad()
445
+
446
+ update_step += 1
447
+ update_time = time.time() - update_time
448
+
449
+ # save checkpoint by save_every
450
+ if local_step > args.gradient_accumulation and update_step % args.save_every == 0 and global_rank == 0:
451
+ current_model_directory = f"{args.save_dir}/model_{update_step}"
452
+ logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}")
453
+ os.makedirs(args.save_dir, exist_ok=True)
454
+ # Bug-1 fix: unwrap DDP/compiled model for saving; works in both single-GPU and multi-GPU modes
455
+ unwrapped_model = model.module if hasattr(model, 'module') else model
456
+ unwrapped_model.save_pretrained(current_model_directory, max_shard_size="500GB", from_pt=True)
457
+ saving_model_weight(unwrapped_model, f"{current_model_directory}/pytorch_model.bin", args)
458
+
459
+ optimizer_checkpoint = {
460
+ "optimizer": optimizer.state_dict(),
461
+ "scheduler": scheduler.state_dict(),
462
+ "update_step": update_step,
463
+ "global_step": global_step,
464
+ "config": run_config,
465
+ "wandb": wandb.run.dir if not args.unset_wandb else None,
466
+ "dtype": args.dtype,
467
+ }
468
+ torch.save(optimizer_checkpoint, f"{current_model_directory}/optimizer.pt")
469
+
470
+ training_state_checkpoint = {
471
+ "global_step": global_step,
472
+ "update_step": update_step,
473
+ "tokens_seen": tokens_seen,
474
+ "tokens_seen_before": tokens_seen_before,
475
+ "update_time": update_time,
476
+ }
477
+ with open(f"{current_model_directory}/training_state.json", "w") as f:
478
+ json.dump(training_state_checkpoint, f, indent=4)
479
+
480
+ # save wandb related info
481
+ if not args.unset_wandb:
482
+ wandb_info = {
483
+ "wandb_id": wandb.run.id,
484
+ }
485
+ with open(f"{args.save_dir}/wandb.json", "w") as f:
486
+ json.dump(wandb_info, f, indent=4)
487
+
488
+ # evaluation
489
+ if update_step % args.eval_every == 0:
490
+ logger.info(f"Performing evaluation at step {update_step}")
491
+ total_loss, evaluated_on_tokens, perplexity = evaluate_model(
492
+ model, tokenizer, pad_idx, global_rank, world_size, device, args
493
+ )
494
+
495
+ if global_rank == 0:
496
+ if not args.unset_wandb:
497
+ wandb.log(
498
+ {
499
+ "eval_loss": total_loss,
500
+ "eval_perplexity": perplexity,
501
+ "eval_tokens": evaluated_on_tokens,
502
+ },
503
+ step=update_step,
504
+ )
505
+ logger.info(f"Eval loss at step {update_step}: {total_loss}, Eval perplexity: {perplexity}")
506
+
507
+ if not layer_wise_flag:
508
+ lr = optimizer.param_groups[0]["lr"]
509
+ else:
510
+ lr = list(optimizer_dict.values())[0].param_groups[0]["lr"]
511
+ tokens_in_update = tokens_seen - tokens_seen_before
512
+ tokens_seen_before = tokens_seen
513
+ batches_in_update = args.gradient_accumulation * world_size
514
+ if not layer_wise_flag:
515
+ total_svd_count = getting_svd_cnt(optimizer)
516
+ else:
517
+ total_svd_count = 0
518
+
519
+ if global_rank == 0:
520
+ if not args.unset_wandb:
521
+ wandb.log(
522
+ {
523
+ "loss": avg_loss,
524
+ "lr": lr,
525
+ "update_step": update_step,
526
+ "tokens_seen": tokens_seen,
527
+ "total_svd_count": total_svd_count,
528
+ "throughput_tokens": tokens_in_update / update_time,
529
+ "throughput_examples": args.total_batch_size / update_time,
530
+ "throughput_batches": batches_in_update / update_time,
531
+ },
532
+ step=update_step,
533
+ )
534
+ update_time = time.time()
535
+
536
+ # ##############################
537
+ # END of training loop
538
+ # ##############################
539
+ logger.info("Training finished")
540
+ if global_rank == 0:
541
+ pbar.close()
542
+
543
+ current_model_directory = f"{args.save_dir}/model_{update_step}"
544
+ if global_rank == 0 and not os.path.exists(current_model_directory):
545
+ logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}")
546
+ os.makedirs(args.save_dir, exist_ok=True)
547
+ # Bug-1 fix: unwrap DDP/compiled model for saving; works in both single-GPU and multi-GPU modes
548
+ unwrapped_model = model.module if hasattr(model, 'module') else model
549
+ unwrapped_model.save_pretrained(current_model_directory, max_shard_size="500GB", from_pt=True)
550
+ saving_model_weight(unwrapped_model, f"{current_model_directory}/pytorch_model.bin", args)
551
+
552
+ optimizer_checkpoint = {
553
+ "optimizer": optimizer.state_dict(),
554
+ "scheduler": scheduler.state_dict(),
555
+ "update_step": update_step,
556
+ "global_step": global_step,
557
+ "config": run_config,
558
+ "wandb": wandb.run.dir if not args.unset_wandb else None,
559
+ "dtype": args.dtype,
560
+ }
561
+ torch.save(optimizer_checkpoint, f"{current_model_directory}/optimizer.pt")
562
+
563
+ training_state_checkpoint = {
564
+ "global_step": global_step,
565
+ "update_step": update_step,
566
+ "tokens_seen": tokens_seen,
567
+ "tokens_seen_before": tokens_seen_before,
568
+ "update_time": update_time,
569
+ }
570
+ with open(f"{current_model_directory}/training_state.json", "w") as f:
571
+ json.dump(training_state_checkpoint, f, indent=4)
572
+
573
+ # Final evaluation
574
+ logger.info("Running final evaluation")
575
+ model.eval()
576
+ del loss, optimizer, scheduler
577
+ import gc
578
+
579
+ gc.collect()
580
+ torch.cuda.empty_cache()
581
+
582
+ total_loss, evaluated_on_tokens, perplexity = evaluate_model(model, tokenizer, pad_idx, global_rank, world_size, device, args)
583
+
584
+ if global_rank == 0:
585
+ if not args.unset_wandb:
586
+ wandb.log(
587
+ {
588
+ "final_eval_loss": total_loss,
589
+ "final_eval_perplexity": perplexity,
590
+ "final_eval_tokens": evaluated_on_tokens,
591
+ },
592
+ step=update_step,
593
+ )
594
+ logger.info(f"Final eval loss: {total_loss}, Final eval perplexity: {perplexity}")
595
+
596
+ logger.info("Script finished successfully")
597
+ print(f"Rank {global_rank} finished successfully")
598
+
599
+
600
+ if __name__ == "__main__":
601
+ print("Starting script")
602
+ args = parse_args(None)
603
+ main(args)
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/components/__init__.py ADDED
File without changes
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/components/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (197 Bytes). View file
 
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/components/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (235 Bytes). View file
 
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/components/__pycache__/checkpoint.cpython-310.pyc ADDED
Binary file (1.92 kB). View file
 
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/components/__pycache__/checkpoint.cpython-311.pyc ADDED
Binary file (3.7 kB). View file
 
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/components/checkpoint.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass, field
8
+ from datetime import timedelta
9
+ from io import BytesIO
10
+ from typing import Any, Dict, List
11
+
12
+ import torch
13
+ from torch.distributed.checkpoint.stateful import Stateful
14
+
15
+
16
+ @dataclass
17
+ class TrainState(Stateful):
18
+ step: int = 0
19
+ skipped_step: int = 0
20
+ token: int = 0
21
+ elapsed: timedelta = timedelta(0)
22
+ global_avg_losses: List[float] = field(default_factory=list)
23
+ global_max_losses: List[float] = field(default_factory=list)
24
+ log_steps: List[int] = field(default_factory=list)
25
+
26
+ def state_dict(self) -> Dict[str, Any]:
27
+ # Only checkpoint global_avg_losses and global_max_losses per log frequency
28
+ # to avoid sync overhead in every iteration.
29
+ global_avg_losses_bytes = BytesIO()
30
+ torch.save(self.global_avg_losses, global_avg_losses_bytes)
31
+ global_max_losses_bytes = BytesIO()
32
+ torch.save(self.global_max_losses, global_max_losses_bytes)
33
+ log_steps_bytes = BytesIO()
34
+ torch.save(self.log_steps, log_steps_bytes)
35
+ return {
36
+ "step": torch.tensor(self.step, dtype=torch.int32),
37
+ "skipped_step": torch.tensor(self.skipped_step, dtype=torch.int32),
38
+ "token": torch.tensor(self.token, dtype=torch.int64),
39
+ "elapsed": self.elapsed,
40
+ "global_avg_losses": global_avg_losses_bytes,
41
+ "global_max_losses": global_max_losses_bytes,
42
+ "log_steps": log_steps_bytes,
43
+ }
44
+
45
+ def load_state_dict(self, state_dict) -> None:
46
+ self.step = state_dict["step"].item()
47
+ self.skipped_step = state_dict.get("skipped_step", 0).item()
48
+ self.token = state_dict["token"].item()
49
+ self.elapsed = state_dict["elapsed"]
50
+ state_dict["global_avg_losses"].seek(0)
51
+ self.global_avg_losses = torch.load(
52
+ state_dict["global_avg_losses"], weights_only=False
53
+ )
54
+ state_dict["global_max_losses"].seek(0)
55
+ self.global_max_losses = torch.load(
56
+ state_dict["global_max_losses"], weights_only=False
57
+ )
58
+ state_dict["log_steps"].seek(0)
59
+ self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
gated_deltanet_1b_v3/gated_deltanet_1b_adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_20260503_013622/exp_data/flame/config_manager.py ADDED
@@ -0,0 +1,981 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import sys
9
+ from collections import defaultdict
10
+ from typing import Tuple
11
+
12
+ import torch
13
+
14
+ try:
15
+ import tomllib
16
+ except ModuleNotFoundError:
17
+ import tomli as tomllib
18
+
19
+ from torchtitan.tools.logging import logger
20
+
21
+ TORCH_DTYPE_MAP = {
22
+ "float16": torch.float16,
23
+ "float32": torch.float32,
24
+ "bfloat16": torch.bfloat16,
25
+ }
26
+
27
+
28
+ def string_list(raw_arg):
29
+ """Comma-separated string list argument."""
30
+ return [s.strip() for s in raw_arg.split(",") if s.strip()]
31
+
32
+
33
+ def check_string_list_argument(args_dict: dict[str, any], fullargname: str):
34
+ section, name = fullargname.split(".")
35
+ # Split string list which are still raw strings.
36
+ if (
37
+ section in args_dict
38
+ and name in args_dict[section]
39
+ and isinstance(args_dict[section][name], str)
40
+ ):
41
+ sec = args_dict[section]
42
+ sec[name] = string_list(sec[name])
43
+
44
+
45
+ class JobConfig:
46
+ """
47
+ A helper class to manage the train configuration.
48
+ Semantics:
49
+ - Default config is loaded from a toml file. If no toml file is provided,
50
+ then the default config is loaded from argparse defaults.
51
+ - if toml file has missing keys, they are filled with argparse defaults.
52
+ - if additional explicit cmd args are provided in addition to the toml
53
+ file, they will override the toml config and the argparse defaults
54
+
55
+ precedence order: cmdline > toml > argparse default
56
+
57
+ Arg parsing semantics:
58
+
59
+ Each argument starts with <prefix>_ which is the section name in the toml file
60
+ followed by name of the option in the toml file. For ex,
61
+ model.name translates to:
62
+ [model]
63
+ name
64
+ in the toml file
65
+ """
66
+
67
+ def __init__(self):
68
+ self.args_dict = None
69
+ # main parser
70
+ self.parser = argparse.ArgumentParser(description="torchtitan arg parser.")
71
+
72
+ self.parser.add_argument(
73
+ "--job.config_file",
74
+ type=str,
75
+ default=None,
76
+ help="Job config file",
77
+ )
78
+
79
+ # job level configs
80
+ self.parser.add_argument(
81
+ "--job.dump_folder",
82
+ type=str,
83
+ default="./torchtitan/outputs",
84
+ help="Folder to dump job outputs",
85
+ )
86
+ self.parser.add_argument(
87
+ "--job.description",
88
+ type=str,
89
+ default="default job",
90
+ help="Description of the job",
91
+ )
92
+ self.parser.add_argument(
93
+ "--job.use_for_integration_test",
94
+ action="store_true",
95
+ help="Add this config to the integration test suite",
96
+ )
97
+ self.parser.add_argument(
98
+ "--job.print_args",
99
+ action="store_true",
100
+ help="Print the args to terminal",
101
+ )
102
+
103
+ # model configs
104
+ self.parser.add_argument(
105
+ "--model.name",
106
+ type=str,
107
+ default="fla",
108
+ help="Which model to train",
109
+ )
110
+ self.parser.add_argument(
111
+ "--model.config",
112
+ type=str,
113
+ default="fla-hub/transformer-1.3B-100B",
114
+ help="Path to the model config",
115
+ )
116
+ self.parser.add_argument(
117
+ "--model.tokenizer_path",
118
+ type=str,
119
+ default="fla-hub/transformer-1.3B-100B",
120
+ help="Tokenizer path",
121
+ )
122
+ self.parser.add_argument(
123
+ "--model.converters",
124
+ type=string_list,
125
+ nargs="+",
126
+ default=[],
127
+ help="""
128
+ Comma separated list of converters to apply to the model.
129
+ For instance, the `float8` converter swaps `torch.nn.Linear`
130
+ with `Float8Linear`. This feature requires you to install 'torchao'
131
+ which can be found here: https://github.com/pytorch/ao
132
+ """,
133
+ )
134
+ self.parser.add_argument(
135
+ "--model.print_after_conversion",
136
+ action="store_true",
137
+ help="""
138
+ If true, model definition will be printed to stdout after all model
139
+ converters have been applied.
140
+ """,
141
+ )
142
+
143
+ # profiling configs
144
+ self.parser.add_argument(
145
+ "--profiling.enable_profiling",
146
+ action="store_true",
147
+ help="Whether to enable pytorch profiler",
148
+ )
149
+ self.parser.add_argument(
150
+ "--profiling.save_traces_folder",
151
+ type=str,
152
+ default="profile_traces",
153
+ help="Trace files location",
154
+ )
155
+ self.parser.add_argument(
156
+ "--profiling.profile_freq",
157
+ type=int,
158
+ default=10,
159
+ help="How often to collect profiler traces, in iterations",
160
+ )
161
+ self.parser.add_argument(
162
+ "--profiling.enable_memory_snapshot",
163
+ action="store_true",
164
+ help="Whether to dump memory snapshot",
165
+ )
166
+ self.parser.add_argument(
167
+ "--profiling.save_memory_snapshot_folder",
168
+ type=str,
169
+ default="memory_snapshot",
170
+ help="Memeory snapshot files location",
171
+ )
172
+
173
+ # optimizer configs
174
+ self.parser.add_argument(
175
+ "--optimizer.name", type=str, default="AdamW", help="Optimizer to use"
176
+ )
177
+ self.parser.add_argument(
178
+ "--optimizer.eps",
179
+ type=float,
180
+ default=1e-8,
181
+ help="Epsilon value for the optimizer.",
182
+ )
183
+ self.parser.add_argument(
184
+ "--optimizer.lr", type=float, default=8e-4, help="Learning rate to use"
185
+ )
186
+ self.parser.add_argument(
187
+ "--optimizer.beta1", type=float, default=0.9,
188
+ help="Exponential moving average hyperparameters to use"
189
+ )
190
+ self.parser.add_argument(
191
+ "--optimizer.beta2", type=float, default=0.95,
192
+ help="Exponential moving average hyperparameters to use"
193
+ )
194
+ self.parser.add_argument(
195
+ "--optimizer.weight_decay", type=float, default=0.1,
196
+ help="Weight decay to use"
197
+ )
198
+ self.parser.add_argument(
199
+ "--optimizer.implementation",
200
+ type=str,
201
+ default="fused",
202
+ choices=["for-loop", "foreach", "fused"],
203
+ help="""
204
+ Specify which optimizer implementation to use:
205
+ - 'fused': Use fused implementation (CUDA only) for best performance.
206
+ - 'foreach': Use some horizontal fusion of tensors for better performance.
207
+ - 'for-loop': Use the default implementation for the optimizer (slowest).
208
+ - more info: https://pytorch.org/docs/stable/optim.html
209
+ """,
210
+ )
211
+ self.parser.add_argument(
212
+ "--optimizer.early_step_in_backward",
213
+ action="store_true",
214
+ help="""
215
+ Whether to apply optimizer in the backward. Caution, optimizer_in_backward
216
+ is not compatible with gradients clipping, users should not call
217
+ register_post_accumulate_grad_hook after the optimizer is built.""",
218
+ )
219
+
220
+ # lr scheduler configs
221
+ self.parser.add_argument(
222
+ "--lr_scheduler.warmup_steps",
223
+ type=int,
224
+ default=200,
225
+ help="Steps for lr scheduler warmup, normally 1/5 of --training.steps",
226
+ )
227
+ self.parser.add_argument(
228
+ "--lr_scheduler.decay_ratio",
229
+ type=float,
230
+ default=None,
231
+ help="""
232
+ Controls the proportion of the training steps allocated to the learning rate decay phase.
233
+
234
+ If `None`, the learning rate will begin decaying immediately after the warmup period.
235
+ Otherwise, the learning rate will remain stable after the warmup period and
236
+ only start decaying during the last `decay_ratio` portion of the total training steps.
237
+
238
+ This is known as the Warmup-Stable-Decay (WSD) schedule, as described in https://arxiv.org/abs/2404.06395.
239
+ """,
240
+ )
241
+ self.parser.add_argument(
242
+ "--lr_scheduler.decay_type",
243
+ type=str,
244
+ default="linear",
245
+ choices=["linear", "sqrt", "cosine"],
246
+ help="""
247
+ Learning rate decay type to use during training:
248
+ - 'linear': linearly decays learning rate from initial to final value
249
+ - 'sqrt': decays learning rate following a 1 minus square root curve
250
+ - 'cosine': smoothly decays learning rate following a cosine curve
251
+ """,
252
+ )
253
+ self.parser.add_argument(
254
+ "--lr_scheduler.lr_min",
255
+ type=float,
256
+ default=0.0,
257
+ help="""
258
+ Min lr ratio for lr scheduler.
259
+
260
+ If provided, the range of decay factor is scaled from 1 to `lr_min`
261
+ to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.lr_min`.
262
+ """,
263
+ )
264
+
265
+ # training configs
266
+ self.parser.add_argument(
267
+ "--training.batch_size", type=int, default=8, help="Batch size"
268
+ )
269
+ self.parser.add_argument(
270
+ "--training.seq_len", type=int, default=2048, help="Sequence length"
271
+ )
272
+ self.parser.add_argument(
273
+ "--training.context_len",
274
+ type=int,
275
+ default=2048,
276
+ help="Max length allowed for each sequence",
277
+ )
278
+ self.parser.add_argument(
279
+ "--training.varlen",
280
+ action="store_true",
281
+ help="Whether to take sequences of variable length as input",
282
+ )
283
+ self.parser.add_argument(
284
+ "--training.gradient_accumulation_steps",
285
+ type=int,
286
+ default=1,
287
+ help="Number of steps to accumulate gradients before updating parameters",
288
+ )
289
+ self.parser.add_argument(
290
+ "--training.steps",
291
+ type=int,
292
+ default=10000,
293
+ help="How many train steps to run",
294
+ )
295
+ self.parser.add_argument(
296
+ "--training.max_norm",
297
+ type=float,
298
+ default=1.0,
299
+ help="Max norm for gradient clipping",
300
+ )
301
+ self.parser.add_argument(
302
+ "--training.skip_nan_inf",
303
+ action="store_true",
304
+ help="Skip batch updates when NaN or INF gradients are encountered during training",
305
+ )
306
+ self.parser.add_argument(
307
+ "--training.dataset",
308
+ default="HuggingFaceFW/fineweb-edu",
309
+ help="Dataset to use, with comma separated values",
310
+ )
311
+ self.parser.add_argument(
312
+ "--training.dataset_name",
313
+ default=None,
314
+ help="The name of the dataset config, with comma separated values if provided",
315
+ )
316
+ self.parser.add_argument(
317
+ "--training.dataset_split",
318
+ default=None,
319
+ help="Dataset split to use, with comma separated values if provided",
320
+ )
321
+ self.parser.add_argument(
322
+ "--training.data_dir",
323
+ default=None,
324
+ help="Data dirs to use, with comma separated values if provided",
325
+ )
326
+ self.parser.add_argument(
327
+ "--training.data_files",
328
+ default=None,
329
+ help="Data files to use, with comma separated values if provided",
330
+ )
331
+ self.parser.add_argument(
332
+ "--training.data_probs",
333
+ default=None,
334
+ help="Data sampling probabilities, with comma separated values if provided",
335
+ )
336
+ self.parser.add_argument(
337
+ "--training.streaming",
338
+ action="store_true",
339
+ help="Whether to load dataset in streaming mode, used for huge dataset",
340
+ )
341
+ self.parser.add_argument(
342
+ "--training.num_workers",
343
+ type=int,
344
+ default=32,
345
+ help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
346
+ )
347
+ self.parser.add_argument(
348
+ "--training.prefetch_factor",
349
+ type=int,
350
+ default=2,
351
+ help="Number of batches loaded in advance by each worker."
352
+ "2 means there will be a total of 2 * num_workers batches prefetched across all workers.",
353
+ )
354
+ self.parser.add_argument(
355
+ "--training.data_parallel_replicate_degree",
356
+ type=int,
357
+ default=1,
358
+ help="""
359
+ The `data_parallel_replicate_degree` argument specifies the degree of
360
+ data parallelism for weight replication. When this value is greater
361
+ than 1, weights will be replicated across `data_parallel_replicate_degree`
362
+ ranks. If `data_parallel_shard_degree` is also greater than 1, the parallelism
363
+ method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
364
+ parallelism method used is DDP (Distributed Data Parallelism).
365
+ 1 means disabled.""",
366
+ )
367
+ self.parser.add_argument(
368
+ "--training.data_parallel_shard_degree",
369
+ type=int,
370
+ default=-1,
371
+ help="""
372
+ The `data_parallel_shard_degree` argument specifies the degree of data
373
+ parallelism for weight sharding. When this value is greater than 1, weights
374
+ will be sharded across `data_parallel_shard_degree` ranks. If
375
+ `data_parallel_replicate_degree` is also greater than 1, the parallelism
376
+ method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
377
+ parallelism method used is FSDP (Fully Sharded Data Parallelism).
378
+
379
+ -1 means leftover ranks will be used (After DP_REPLICATE/SP/PP). Note that
380
+ only `data_parallel_shard_degree` can be negative. 1 means disabled.""",
381
+ )
382
+ self.parser.add_argument(
383
+ "--training.enable_cpu_offload",
384
+ action="store_true",
385
+ help="""
386
+ Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP""",
387
+ )
388
+ self.parser.add_argument(
389
+ "--training.tensor_parallel_degree",
390
+ type=int,
391
+ default=1,
392
+ help="Tensor Parallelism degree. 1 means disabled.",
393
+ )
394
+ self.parser.add_argument(
395
+ "--training.disable_loss_parallel",
396
+ action="store_true",
397
+ help="Whether to apply loss parallel when sequence parallel is enabled",
398
+ )
399
+ self.parser.add_argument(
400
+ "--training.fsdp_reshard_after_forward",
401
+ type=str,
402
+ default="default",
403
+ choices=["default", "always", "never"],
404
+ help="""
405
+ `reshard_after_forward` specifies the policy for applying `reshard_after_forward`
406
+ within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward,
407
+ trading off memory and communication. See torch's `fully_shard` API for more documentation
408
+ on `reshard_after_forward`.
409
+ The supported policies include "default", "always" and "never":
410
+ - "default" applies default resharding behavior, implementing "smart defaults" for known optimal
411
+ scenarios.
412
+ - "always" will enable `reshard_after_forward` for all forward passes.
413
+ - "never" will disable `reshard_after_forward` for all forward passes.
414
+ """,
415
+ )
416
+ self.parser.add_argument(
417
+ "--training.mixed_precision_param",
418
+ type=str,
419
+ default="bfloat16",
420
+ choices=["bfloat16", "float32"],
421
+ help="""
422
+ torch dtype to use for parameters when applying mixed precision via fully_shard or torch.autocast.
423
+ This feature takes effect via fully_shard when data_parallel_shard_degree > 1 or
424
+ context_parallel_degree > 1; it takes effect via torch.autocast when data_replicate_degree >= 1
425
+ and no other parallelism is enabled, i.e. under DDP or single-device training.
426
+ """,
427
+ )
428
+ self.parser.add_argument(
429
+ "--training.mixed_precision_reduce",
430
+ type=str,
431
+ default="float32",
432
+ choices=["float32"],
433
+ help="""
434
+ torch dtype to use for reductions when applying mixed precision via FSDP.
435
+ This feature only takes effect when data_parallel_shard_degree > 1
436
+ """,
437
+ )
438
+ self.parser.add_argument(
439
+ "--training.compile",
440
+ action="store_true",
441
+ help="Whether to compile the model",
442
+ )
443
+ self.parser.add_argument(
444
+ "--training.gc_freq",
445
+ type=int,
446
+ default=50,
447
+ help="Python garbage control scheduling interval, in steps",
448
+ )
449
+ self.parser.add_argument(
450
+ "--training.seed",
451
+ type=int,
452
+ default=42,
453
+ help="Choose the base RNG seed used for training",
454
+ )
455
+ self.parser.add_argument(
456
+ "--training.deterministic",
457
+ action="store_true",
458
+ help="Use deterministic algorithms wherever possible, may be slower",
459
+ )
460
+ # ------ jinxin ------ #
461
+ self.parser.add_argument(
462
+ "--training.val_times",
463
+ type=int,
464
+ default=0,
465
+ help="Number of times to evaluate val PPL during training. 0 means no intermediate eval. "
466
+ "e.g. 10 means evaluate every (total_steps // 10) steps.",
467
+ )
468
+ self.parser.add_argument(
469
+ "--training.val_data_dir",
470
+ type=str,
471
+ default=None,
472
+ help="Path to the validation data directory containing parquet files. "
473
+ "If None, defaults to 'data/wiki_val/' relative to cwd.",
474
+ )
475
+ # metrics configs
476
+ self.parser.add_argument(
477
+ "--metrics.log_freq",
478
+ type=int,
479
+ default=10,
480
+ help="How often to log metrics to TensorBoard, in iterations",
481
+ )
482
+ self.parser.add_argument(
483
+ "--metrics.enable_tensorboard",
484
+ action="store_true",
485
+ help="Whether to log metrics to TensorBoard",
486
+ )
487
+ self.parser.add_argument(
488
+ "--metrics.disable_color_printing",
489
+ action="store_true",
490
+ help="Whether to disable color printing in logs",
491
+ )
492
+ self.parser.add_argument(
493
+ "--metrics.save_tb_folder",
494
+ type=str,
495
+ default="tb",
496
+ help="Folder to dump TensorBoard states",
497
+ )
498
+ self.parser.add_argument(
499
+ "--metrics.save_for_all_ranks",
500
+ action="store_true",
501
+ default=False,
502
+ help="""
503
+ Whether to save TensorBoard/Wandb metrics only for rank 0 or for all ranks.
504
+ When this option is False and pipeline_parallel_degree is > 1, the metrics
505
+ component uses the 0th rank of the last stage pipeline group, which is the
506
+ only stage that computes loss metrics.
507
+ """,
508
+ )
509
+ self.parser.add_argument(
510
+ "--metrics.enable_wandb",
511
+ action="store_true",
512
+ help="Whether to log metrics to Weights & Biases",
513
+ )
514
+ self.parser.add_argument(
515
+ "--no-metrics.enable_wandb",
516
+ dest="metrics.enable_wandb",
517
+ action="store_false",
518
+ help="Disable Weights & Biases logging (e.g. to avoid disk quota on ~/.cache/wandb)",
519
+ )
520
+
521
+ self.parser.add_argument(
522
+ "--experimental.enable_async_tensor_parallel",
523
+ action="store_true",
524
+ help="Whether to apply async tensor parallel (currently only effective when compile is enabled)",
525
+ )
526
+ self.parser.add_argument(
527
+ "--experimental.pipeline_parallel_degree",
528
+ type=int,
529
+ default=1,
530
+ help="""
531
+ Pipeline Parallelism degree, or number of ranks. 1 means disabled.
532
+ If using looped schedules, this still specifies the number of physical ranks, not the number
533
+ of stages. Stages per rank are inferred from split points degree, and schedule.""",
534
+ )
535
+ self.parser.add_argument(
536
+ "--experimental.pipeline_parallel_split_points",
537
+ type=string_list,
538
+ nargs="+",
539
+ default=[],
540
+ help="""
541
+ Specify comma-separated names of modules to use as the beginning of a split point.
542
+
543
+ e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
544
+ the first containing all the layers up to layers.0,
545
+ the second containing layers.0 and up to layers.2,
546
+ the third containing layers.2 and all the remaining layers.
547
+
548
+ Note: fully-automated splitting may be enabled in the future,
549
+ but currently the split points must be specified manually.""",
550
+ )
551
+ self.parser.add_argument(
552
+ "--experimental.pipeline_parallel_schedule",
553
+ type=str,
554
+ default="1F1B",
555
+ help="""
556
+ Specify the Pipeline Parallel schedule to use. The supported schedules are:
557
+ https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/distributed/pipelining/schedules.py#L2161.
558
+ The schedule must be compatible with the split points and stages_per_rank.
559
+
560
+ Looped schedules (e.g. Interleaved1F1B) require specifying pipeline_parallel_degree = number of ranks,
561
+ and split_points = number of stages - 1
562
+ """,
563
+ )
564
+ self.parser.add_argument(
565
+ "--experimental.pipeline_parallel_schedule_csv",
566
+ type=str,
567
+ default="",
568
+ help="""
569
+ Specify the path to the pipeline parallel schedule csv file to use.
570
+ The pipeline_parallel_schedule argument must be either
571
+ PipelineScheduleSingle, PipelineScheduleMulti, or _PipelineScheduleRuntime.
572
+ """,
573
+ )
574
+
575
+ self.parser.add_argument(
576
+ "--experimental.pipeline_parallel_microbatches",
577
+ type=int,
578
+ default=None,
579
+ help="""
580
+ How many microbatches to split the global training batch into when using pipeline parallelism.
581
+
582
+ The global training batch size must be evenly divisible by the number of microbatches.
583
+
584
+ The default value will be the number of pipeline stages, if unspecified.
585
+ """,
586
+ )
587
+ self.parser.add_argument(
588
+ "--experimental.enable_compiled_autograd",
589
+ action="store_true",
590
+ help="Enable CompiledAutograd to compile the backward.",
591
+ )
592
+ self.parser.add_argument(
593
+ "--experimental.context_parallel_degree",
594
+ type=int,
595
+ default=1,
596
+ help="Context parallelism degree. 1 means disabled.",
597
+ )
598
+ self.parser.add_argument(
599
+ "--experimental.context_parallel_rotate_method",
600
+ type=str,
601
+ default="allgather",
602
+ help="""
603
+ The collective to use in context parallel SDPA for kv shards exchange.
604
+
605
+ 'allgather' means to all-gather all kv shards on ranks after the first sub-SDPA computation,
606
+
607
+ 'alltoall' means to all-to-all shuffle the kv shards.
608
+
609
+ The default value is 'allgather'.
610
+ """,
611
+ )
612
+ # I'm not particularly fond of this. Users can choose to write their own wrapper
613
+ # module and import TorchTitan training loop and execute it, which look cleaner.
614
+ # One reason to provide this option is to allow users to use the existing run script.
615
+ # While the script is pretty trivial now, we may add more logic when integrating
616
+ # with TorchFT.
617
+ # This option is subject to change and may be deleted in the future.
618
+ self.parser.add_argument(
619
+ "--experimental.custom_model_path",
620
+ type=str,
621
+ default="",
622
+ help="""
623
+ The --custom_model_path option allows to specify a custom path to a model module
624
+ that is not natively implemented within TorchTitan.
625
+ Acceptable values are the file system path to the module (e.g., my_models/model_x)
626
+ dotted import module (e.g., some_package.model_x).
627
+ """,
628
+ )
629
+ # checkpointing configs
630
+ self.parser.add_argument(
631
+ "--checkpoint.enable_checkpoint",
632
+ action="store_true",
633
+ help="Whether to enable checkpoint",
634
+ )
635
+ self.parser.add_argument(
636
+ "--checkpoint.folder",
637
+ type=str,
638
+ default="checkpoint",
639
+ help="""
640
+ The folder to store the checkpoints.
641
+ When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
642
+ """,
643
+ )
644
+ self.parser.add_argument(
645
+ "--checkpoint.initial_load_path", type=str, default=None,
646
+ help="""
647
+ This option specifies the path to the initial checkpoint to load, which is
648
+ particularly useful for resuming training from a previous run with a
649
+ different output path or when loading a checkpoint from a pre-trained model.
650
+ If the checkpoint folder for the current run is not empty,
651
+ located at {--job.dump_folder}/{--checkpoint.folder}, this option will be ignored.
652
+ This feature allows users to load an initial checkpoint from a different folder and
653
+ continue training, saving new checkpoints to the specified folder without affecting
654
+ the existing ones.
655
+
656
+ Note that the path should contain the full path to the checkpoint folder,
657
+ including the step number, if any; for example,
658
+ "//pre_train/checkpoints/llama3/llama3_8b/step_10000".
659
+ """
660
+ )
661
+ self.parser.add_argument(
662
+ "--checkpoint.initial_load_model_weights_only",
663
+ dest='checkpoint.initial_load_model_weights_only', action="store_true", default=True,
664
+ help="""
665
+ This option specifies if only the model weights should be loaded during the initial
666
+ checkpoint load. The option is only used when `initial_load_path` is specified, and
667
+ only applies to a model_weights_only checkpoint. Loading a periodic checkpoint
668
+ may lead to unexpected behavior if this option is set to True.
669
+ If False, the checkpoint at `initial_load_path` is treated as a standard training
670
+ checkpoint, including optimizer and training states.
671
+ The default setting for this option is True. Note that you will have to use
672
+ `--checkpoint.no_initial_load_model_weights_only` to override the default setting.
673
+ """
674
+ )
675
+ self.parser.add_argument(
676
+ "--checkpoint.no_initial_load_model_weights_only",
677
+ dest='checkpoint.initial_load_model_weights_only', action="store_false",
678
+ )
679
+ self.parser.add_argument(
680
+ "--checkpoint.interval",
681
+ type=int,
682
+ default=500,
683
+ help="Checkpointing interval in steps.",
684
+ )
685
+ self.parser.add_argument(
686
+ "--checkpoint.last_save_model_weights_only",
687
+ action="store_true",
688
+ help="""
689
+ When last_save_model_weights_only=True, only model weights will be saved at the end of training,
690
+ the last save. With this, checkpoints can be loaded using `torch.load(..., weights_only=True)`
691
+ after conversion. When last_save_model_weights_only=False, the full checkpoint will be saved.
692
+ A full checkpoint includes model, optimizer and train_state, which can be used to resume training.
693
+ The default value is false.
694
+ """,
695
+ )
696
+ self.parser.add_argument(
697
+ "--checkpoint.export_dtype",
698
+ type=str,
699
+ default="float32",
700
+ choices=["float16", "bfloat16", "float32"],
701
+ help="""
702
+ Converts to the specified precision when training completes and model_weights_only=true.
703
+ Currently supports float32, float16, and bfloat16.
704
+ The default value is float32.
705
+ """,
706
+ )
707
+ self.parser.add_argument(
708
+ "--checkpoint.create_seed_checkpoint",
709
+ action="store_true",
710
+ help="""
711
+ Initializes the full model without applying parallelisms, and then saves it as a seed checkpoint.
712
+ Note: requires user to call train.py without specifying any parallelisms, e.g. NGPU=1.
713
+ Could be implemented as a separate script, but this way shares more code.
714
+ """,
715
+ )
716
+ self.parser.add_argument(
717
+ "--checkpoint.async_mode",
718
+ type=str,
719
+ default="disabled",
720
+ help="""
721
+ Which async checkpoint mode to use. Currently there are 3 different modes.
722
+ 1. "disabled": synchronized checkpointing will be used.
723
+ 2. "async": torch.distributed.checkpoint.async_save will be used.
724
+ 3. "async_with_pinned_mem": this option utilizes a dedicated pinned memory
725
+ space and creates a separate process for faster GPU->CPU transfer
726
+ performance and eliminating GIL contention. The cost is increased CPU
727
+ memory usage. If insufficient CPU memory is available, performance may
728
+ degrade due to memory paging. For most users, "async" should suffice as
729
+ the performance overhead is typically small (on the order of tens of
730
+ seconds) compared to checkpointing frequency. This mode can be employed
731
+ to pursue near-zero checkpointing times (e.g., < 1 second) given
732
+ appropriate hardware support such as ample CPU memory and fast PCIe.
733
+
734
+ "disabled" is the default mode.
735
+ """,
736
+ )
737
+ self.parser.add_argument(
738
+ "--checkpoint.keep_latest_k",
739
+ type=int,
740
+ default=0,
741
+ help="""
742
+ Keeps only the latest k checkpoints, and purging older ones. If 0, keep all checkpoints.
743
+ 0 is the default value. k cannot be 1 as the last one may be in the process of being
744
+ saved. As a result, the metadata of the last one may not be ready yet.
745
+ """,
746
+ )
747
+ self.parser.add_argument(
748
+ "--checkpoint.load_step",
749
+ type=int,
750
+ default=-1,
751
+ help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.",
752
+ )
753
+ self.parser.add_argument(
754
+ "--checkpoint.exclude_from_loading",
755
+ type=string_list,
756
+ nargs="*",
757
+ default=[],
758
+ help="""
759
+ Exclude specific keys from being loaded from the checkpoint.
760
+ Provide a comma-separated list of keys to exclude, e.g. 'optimizer,lr_scheduler,dataloader'.
761
+ This will load the model only, excluding the specified keys.
762
+ """,
763
+ )
764
+ # activation checkpointing configs
765
+ self.parser.add_argument(
766
+ "--activation_checkpoint.mode",
767
+ type=str,
768
+ default="selective",
769
+ help="Type of activation checkpointing to use ['none', 'full', 'selective']",
770
+ )
771
+ self.parser.add_argument(
772
+ "--activation_checkpoint.selective_ac_option",
773
+ type=str,
774
+ default="2", # 2 = checkpoint every other layer
775
+ help="""
776
+ Selective activation checkpointing options ['int', 'op'].
777
+ 'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
778
+ """,
779
+ )
780
+
781
+ self.parser.add_argument(
782
+ "--activation_offload.mode",
783
+ type=str,
784
+ default="none",
785
+ help="""
786
+ if we are using activation offload or not. Options are ['none', 'full'].
787
+ """,
788
+ )
789
+
790
+ # float8 configs
791
+ self.parser.add_argument(
792
+ "--float8.enable_fsdp_float8_all_gather",
793
+ action="store_true",
794
+ help="Whether enable float8 all-gather in FSDP, recommended for tensorwise scaling",
795
+ )
796
+ self.parser.add_argument(
797
+ "--float8.precompute_float8_dynamic_scale_for_fsdp",
798
+ action="store_true",
799
+ help="Whether precompute float8 scales dynamically for FSDP, recommended for tensorwise scaling",
800
+ )
801
+ self.parser.add_argument(
802
+ "--float8.force_recompute_fp8_weight_in_bwd",
803
+ action="store_true",
804
+ help="""
805
+ Whether to force the recomputation of FP8 weights during backward pass.
806
+ When using FSDP with tensorwise scaling, it is recommended to enable
807
+ `force_recompute_fp8_weight_in_bwd` to prevent saving unsharded FP8 weights
808
+ for backward computation.
809
+ """,
810
+ )
811
+ self.parser.add_argument(
812
+ "--float8.recipe_name",
813
+ type=str,
814
+ default=None,
815
+ choices=["tensorwise", "rowwise", "rowwise_with_gw_hp"],
816
+ help="""
817
+ If specified, creates float8 config from recipe name, valid choices are
818
+ `tensorwise`, `rowwise` and `rowwise_with_gw_hp`.
819
+ """,
820
+ )
821
+
822
+ # communications library settings
823
+ self.parser.add_argument(
824
+ "--comm.init_timeout_seconds",
825
+ type=int,
826
+ default=300,
827
+ help="Timeout for communication operations, during initialization and first train step.",
828
+ )
829
+ self.parser.add_argument(
830
+ "--comm.train_timeout_seconds",
831
+ type=int,
832
+ default=100,
833
+ help=(
834
+ "Timeout for communication operations after the first train step -- "
835
+ "usually a tighter bound than during initialization."
836
+ ),
837
+ )
838
+ self.parser.add_argument(
839
+ "--comm.trace_buf_size",
840
+ type=int,
841
+ default=20000,
842
+ help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
843
+ )
844
+
845
+ # memory estimation settings
846
+ self.parser.add_argument(
847
+ "--memory_estimation.enabled",
848
+ help="Whether to estimate memory usage for FSDP",
849
+ action="store_true",
850
+ )
851
+
852
+ self.parser.add_argument(
853
+ "--memory_estimation.disable_fake_mode",
854
+ help="Whether to estimate memory under FakeTensorMode",
855
+ action="store_true",
856
+ )
857
+
858
+ self.parser.add_argument(
859
+ "--fault_tolerance.enable",
860
+ action="store_true",
861
+ help="""
862
+ Enable TorchFT integration. When TorchFT is enabled, HSDP will be used.
863
+ And --fault_tolerance.data_parallel_replicate_degree should be 1 and
864
+ --fault_tolerance.group_size will be used to control the maximum
865
+ replicate group size as the replicate group size is dynamic.
866
+
867
+ Note that this is still an experimental feature.
868
+ """,
869
+ )
870
+
871
+ self.parser.add_argument(
872
+ "--fault_tolerance.replica_id",
873
+ type=int,
874
+ default=0,
875
+ help="The TorchFT replica ID of this run.",
876
+ )
877
+
878
+ self.parser.add_argument(
879
+ "--fault_tolerance.group_size",
880
+ type=int,
881
+ default=0,
882
+ help="""
883
+ The number of TorchFT replicate groups. This number will be used for
884
+ dataloader to split the dataset across the replicate groups and FSDP
885
+ dimension
886
+ """,
887
+ )
888
+
889
+ self.parser.add_argument(
890
+ "--fault_tolerance.min_replica_size",
891
+ type=int,
892
+ default=1,
893
+ help="The minimum number of FT replica for each step.",
894
+ )
895
+
896
+ def to_dict(self):
897
+ return self.args_dict
898
+
899
+ def parse_args(self, args_list: list = sys.argv[1:]):
900
+ args, cmd_args = self.parse_args_from_command_line(args_list)
901
+ config_file = getattr(args, "job.config_file", None)
902
+ # build up a two level dict
903
+ args_dict = self._args_to_two_level_dict(args)
904
+ if config_file is not None:
905
+ try:
906
+ with open(config_file, "rb") as f:
907
+ for k, v in tomllib.load(f).items():
908
+ # to prevent overwrite of non-specified keys
909
+ args_dict[k] |= v
910
+ except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
911
+ logger.exception(
912
+ f"Error while loading the configuration file: {config_file}"
913
+ )
914
+ logger.exception(f"Error details: {str(e)}")
915
+ raise e
916
+
917
+ # Checking string-list arguments are properly split into a list
918
+ # if split-points came from 'args' (from cmd line) it would have already been parsed into a list by that parser
919
+ string_list_argnames = self._get_string_list_argument_names()
920
+ for n in string_list_argnames:
921
+ check_string_list_argument(args_dict, n)
922
+
923
+ # override args dict with cmd_args
924
+ cmd_args_dict = self._args_to_two_level_dict(cmd_args)
925
+ for section, section_args in cmd_args_dict.items():
926
+ for k, v in section_args.items():
927
+ args_dict[section][k] = v
928
+
929
+ self.args_dict = args_dict
930
+
931
+ for k, v in args_dict.items():
932
+ class_type = type(k.title(), (), v)
933
+ setattr(self, k, class_type())
934
+ self._validate_config()
935
+
936
+ def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
937
+ args_dict = defaultdict(defaultdict)
938
+ for k, v in vars(args).items():
939
+ first_level_key, second_level_key = k.split(".", 1)
940
+ args_dict[first_level_key][second_level_key] = v
941
+ return args_dict
942
+
943
+ def _validate_config(self) -> None:
944
+ # TODO: Add more mandatory validations
945
+ assert self.model.config
946
+ assert self.model.tokenizer_path
947
+
948
+ def _get_string_list_argument_names(self) -> list[str]:
949
+ """Get the parser argument names of type `string_list`."""
950
+ string_list_args = [
951
+ v.dest for v in self.parser._actions if v.type is string_list
952
+ ]
953
+ return string_list_args
954
+
955
+ def parse_args_from_command_line(
956
+ self, args_list
957
+ ) -> Tuple[argparse.Namespace, argparse.Namespace]:
958
+ """
959
+ Parse command line arguments and return the parsed args and the command line only args
960
+ """
961
+ args = self.parser.parse_args(args_list)
962
+ string_list_argnames = set(self._get_string_list_argument_names())
963
+
964
+ # aux parser to parse the command line only args, with no defaults from main parser
965
+ aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
966
+ for arg, val in vars(args).items():
967
+ if isinstance(val, bool):
968
+ aux_parser.add_argument(
969
+ "--" + arg, action="store_true" if val else "store_false"
970
+ )
971
+ elif arg in string_list_argnames:
972
+ # without this special case, type inference breaks here,
973
+ # since the inferred type is just 'list' and it ends up flattening
974
+ # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
975
+ aux_parser.add_argument("--" + arg, type=string_list)
976
+ else:
977
+ aux_parser.add_argument("--" + arg, type=type(val))
978
+
979
+ cmd_args, _ = aux_parser.parse_known_args(args_list)
980
+
981
+ return args, cmd_args