Add 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +9 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/LICENSE +21 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/README.md +519 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/.metadata +3 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__0_0.distcp +3 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__1_0.distcp +3 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__2_0.distcp +3 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__3_0.distcp +3 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__4_0.distcp +3 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__5_0.distcp +3 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__6_0.distcp +3 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__7_0.distcp +3 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/config.json +34 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/blt_transformer_1000hash.json +98 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/blt_transformer_1_5B.json +99 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/blt_transformer_380M.json +98 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/delta_net_1B.json +29 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/delta_net_340M.json +26 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/gated_deltanet_1B.json +22 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/gated_deltanet_340M.json +22 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/gated_deltanet_h_340M.json +28 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/gla_1B.json +24 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/gla_340M.json +24 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/gla_7B.json +25 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/gsa_340M.json +29 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/mergenet_340M.json +34 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/mergenet_64M.json +34 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/qwen3_next_1B.json +44 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/qwen3_next_350M.json +44 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/transformer_1B.json +22 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/transformer_340M.json +18 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/transformer_7B.json +21 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__init__.py +1 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/__init__.cpython-310.pyc +0 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/__init__.cpython-311.pyc +0 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/__init__.cpython-313.pyc +0 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/config_manager.cpython-310.pyc +0 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/config_manager.cpython-311.pyc +0 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/data.cpython-310.pyc +0 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/data.cpython-311.pyc +0 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/train.cpython-310.pyc +0 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/train.cpython-311.pyc +0 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/train.cpython-313.pyc +0 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/c4_test.py +603 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/components/__init__.py +0 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/components/__pycache__/__init__.cpython-310.pyc +0 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/components/__pycache__/__init__.cpython-311.pyc +0 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/components/__pycache__/checkpoint.cpython-310.pyc +0 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/components/__pycache__/checkpoint.cpython-311.pyc +0 -0
- 1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/components/checkpoint.py +59 -0
.gitattributes
CHANGED
|
@@ -681,3 +681,12 @@ transformer_pp_340m_c4/transformer_pp_340m_c4_valc4_lion_lr1e_4_b1_0_9_b2_0_99_e
|
|
| 681 |
1b_archs_fwe/gla_1b_fwe_adan_lr3e_3_b1_0_9_b2_0_92_b3_0_99_eps_1e_8_20260518_235004/exp_data/checkpoint/step-30720/__5_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 682 |
1b_archs_fwe/gla_1b_fwe_adan_lr3e_3_b1_0_9_b2_0_92_b3_0_99_eps_1e_8_20260518_235004/exp_data/checkpoint/step-30720/__6_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 683 |
1b_archs_fwe/gla_1b_fwe_adan_lr3e_3_b1_0_9_b2_0_92_b3_0_99_eps_1e_8_20260518_235004/exp_data/checkpoint/step-30720/__7_0.distcp filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 681 |
1b_archs_fwe/gla_1b_fwe_adan_lr3e_3_b1_0_9_b2_0_92_b3_0_99_eps_1e_8_20260518_235004/exp_data/checkpoint/step-30720/__5_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 682 |
1b_archs_fwe/gla_1b_fwe_adan_lr3e_3_b1_0_9_b2_0_92_b3_0_99_eps_1e_8_20260518_235004/exp_data/checkpoint/step-30720/__6_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 683 |
1b_archs_fwe/gla_1b_fwe_adan_lr3e_3_b1_0_9_b2_0_92_b3_0_99_eps_1e_8_20260518_235004/exp_data/checkpoint/step-30720/__7_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 684 |
+
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/.metadata filter=lfs diff=lfs merge=lfs -text
|
| 685 |
+
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__0_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 686 |
+
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__1_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 687 |
+
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__2_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 688 |
+
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__3_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 689 |
+
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__4_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 690 |
+
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__5_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 691 |
+
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__6_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 692 |
+
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__7_0.distcp filter=lfs diff=lfs merge=lfs -text
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023-2025 Songlin Yang, Yu Zhang
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/README.md
ADDED
|
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
|
| 3 |
+
# 🔥 Flame: Flash Language Modeling Made Easy
|
| 4 |
+
|
| 5 |
+
[](https://deepwiki.com/fla-org/flame)
|
| 6 |
+
|
| 7 |
+
</div>
|
| 8 |
+
|
| 9 |
+
Welcome to 🔥 `flame`, a minimal and efficient framework built on `torchtitan` for language models with blazing efficiency.
|
| 10 |
+
|
| 11 |
+
**Feature Highlights:**
|
| 12 |
+
|
| 13 |
+
- 🚀 Minimal, easy-to-use, extensible training framework
|
| 14 |
+
- 🤗 Seamless integration with `fla` and `transformers`
|
| 15 |
+
- 🔄 Zero-cost data preprocessing: online tokenization, dataset shuffling, and multiple datasets support
|
| 16 |
+
- 🔮 4D parallelism (coming soon)
|
| 17 |
+
|
| 18 |
+
## Setup
|
| 19 |
+
|
| 20 |
+
To get started, clone the `flame` repository and install the required dependencies:
|
| 21 |
+
|
| 22 |
+
```bash
|
| 23 |
+
git clone https://github.com/fla-org/flame.git
|
| 24 |
+
cd flame
|
| 25 |
+
pip install .
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
Install the latest version of fla
|
| 29 |
+
```
|
| 30 |
+
pip uninstall flash-linear-attention && pip install -U --no-use-pep517 git+https://github.com/fla-org/flash-linear-attention
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
[Important] Install specific version of torchtitan
|
| 34 |
+
```
|
| 35 |
+
pip install git+https://github.com/pytorch/torchtitan.git@0b44d4c
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
## Dataset Preparation
|
| 40 |
+
To download the dataset to your local disk, create a new Python file with the following content and execute it:
|
| 41 |
+
|
| 42 |
+
```py
|
| 43 |
+
from datasets import load_dataset
|
| 44 |
+
|
| 45 |
+
# load fineweb-edu with parallel processing
|
| 46 |
+
dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", num_proc=64, cache_dir="/your/cache/path")
|
| 47 |
+
|
| 48 |
+
# or load a subset with roughly 100B tokens, suitable for small- or medium-sized experiments
|
| 49 |
+
dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_proc=64, cache_dir="/your/cache/path")
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
## Training Recipes
|
| 53 |
+
|
| 54 |
+
Here's an example of training a 340M FLA Transformer model with a LLaMA-like architecture from scratch on a 100BT subset of the Fineweb-edu corpus ~~in streaming mode~~. (Do not use streaming mode if you are concerned about resuming training.)
|
| 55 |
+
|
| 56 |
+
> [!WARNING]
|
| 57 |
+
> If the dataset is not downloaded beforehand, the streaming mode will attempt to fetch it from a remote server and download it on-the-fly, which can be highly unstable during training due to network issues.
|
| 58 |
+
> For stable training, ensure the dataset is downloaded locally (see [**Dataset Preparation**](#dataset-preparation)). Otherwise, we assume you are only testing the new corpus.
|
| 59 |
+
|
| 60 |
+
```sh
|
| 61 |
+
bash train.sh \
|
| 62 |
+
--job.config_file flame/models/fla.toml \
|
| 63 |
+
--job.dump_folder exp/transformer-340M-4K-10B/batch1.seqlen65536.context4096.warmup1024.update1.steps20480.lr1e-3.cosine \
|
| 64 |
+
--model.config configs/transformer_340M.json \
|
| 65 |
+
--model.tokenizer_path fla-hub/transformer-1.3B-100B \
|
| 66 |
+
--optimizer.name AdamW \
|
| 67 |
+
--optimizer.eps 1e-15 \
|
| 68 |
+
--optimizer.lr 1e-3 \
|
| 69 |
+
--lr_scheduler.warmup_steps 1024 \
|
| 70 |
+
--lr_scheduler.lr_min 0.1 \
|
| 71 |
+
--lr_scheduler.decay_type cosine \
|
| 72 |
+
--training.batch_size 1 \
|
| 73 |
+
--training.seq_len 65536 \
|
| 74 |
+
--training.context_len 4096 \
|
| 75 |
+
--training.varlen \
|
| 76 |
+
--training.gradient_accumulation_steps 1 \
|
| 77 |
+
--training.steps 20480 \
|
| 78 |
+
--training.max_norm 1.0 \
|
| 79 |
+
--training.skip_nan_inf \
|
| 80 |
+
--training.dataset HuggingFaceFW/fineweb-edu \
|
| 81 |
+
--training.dataset_name sample-100BT \
|
| 82 |
+
--training.dataset_split train \
|
| 83 |
+
--training.num_workers 32 \
|
| 84 |
+
--training.prefetch_factor 2 \
|
| 85 |
+
--training.seed 42 \
|
| 86 |
+
--training.compile \
|
| 87 |
+
--checkpoint.interval 2048 \
|
| 88 |
+
--checkpoint.load_step -1 \
|
| 89 |
+
--checkpoint.keep_latest_k 2 \
|
| 90 |
+
--metrics.log_freq 1
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
You can specify the number of GPUs by setting the environment variable `NGPU`, which defaults to 8.
|
| 94 |
+
**For single-GPU debugging, set `NGPU=1`.**
|
| 95 |
+
|
| 96 |
+
We provide several [config files](https://github.com/fla-org/flame/tree/main/configs) for different models.
|
| 97 |
+
By default, the learning rate is set to 1e-3 with a cosine scheduler. Other schedulers, such as WSD (wsd), are also supported.
|
| 98 |
+
|
| 99 |
+
**Key parameters:**
|
| 100 |
+
- `--lr_scheduler.decay_ratio`: The proportion of the steps allocated to the decay phase. The learning rate will remain stable after the warmup period and only start decaying during the last `decay_ratio` portion of the total training steps, which is known as the Warmup-Stable-Decay (WSD) schedule.
|
| 101 |
+
- `--lr_scheduler.warmup_steps`: The number of steps for the learning rate warmup phase.
|
| 102 |
+
- `--training.steps`: Total number of training steps.
|
| 103 |
+
- `--training.batch_size`: Batch size per device, must be 1 if `--training.varlen` is set.
|
| 104 |
+
- `--training.seq_len`: The length of each sequence in the batch, which is concatenated from multiple samples.
|
| 105 |
+
- `--training.context_len`: The max allowed length of a sample. For non-varlen mode, this is equivalent to `seq_len`.
|
| 106 |
+
- `--training.varlen`: Whether to conduct variable-length sequence training.
|
| 107 |
+
- `--training.gradient_accumulation_steps`: Number of gradient accumulation steps.
|
| 108 |
+
|
| 109 |
+
> [!WARNING]
|
| 110 |
+
> The total number of tokens processed per batch, referred to as `global_batch_size`, is calculated as batch_size × gradient_accumulation_steps × num_gpus.
|
| 111 |
+
> Each step processes `global_batch_size * seq_len` tokens.
|
| 112 |
+
> Monitor the value of `global_batch_size`, `warmup_steps`, and `steps` carefully when modifying any of the hyperparameters!
|
| 113 |
+
|
| 114 |
+
For a detailed explanation of all parameters, run:
|
| 115 |
+
|
| 116 |
+
```sh
|
| 117 |
+
bash train.sh -h
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
<details>
|
| 121 |
+
<summary>Usage</summary>
|
| 122 |
+
|
| 123 |
+
```py
|
| 124 |
+
options:
|
| 125 |
+
-h, --help show this help message and exit
|
| 126 |
+
--job.config_file JOB.CONFIG_FILE
|
| 127 |
+
Job config file
|
| 128 |
+
--job.dump_folder JOB.DUMP_FOLDER
|
| 129 |
+
Folder to dump job outputs
|
| 130 |
+
--job.description JOB.DESCRIPTION
|
| 131 |
+
Description of the job
|
| 132 |
+
--job.use_for_integration_test
|
| 133 |
+
Add this config to the integration test suite
|
| 134 |
+
--job.print_args Print the args to terminal
|
| 135 |
+
--model.config MODEL.CONFIG
|
| 136 |
+
Path to the model config
|
| 137 |
+
--model.norm_type MODEL.NORM_TYPE
|
| 138 |
+
Type of layer normalization to use [layernorm,
|
| 139 |
+
np_layernorm, rmsnorm, fused_rmsnorm]
|
| 140 |
+
--model.tokenizer_path MODEL.TOKENIZER_PATH
|
| 141 |
+
Tokenizer path
|
| 142 |
+
--profiling.enable_profiling
|
| 143 |
+
Whether to enable pytorch profiler
|
| 144 |
+
--profiling.save_traces_folder PROFILING.SAVE_TRACES_FOLDER
|
| 145 |
+
Trace files location
|
| 146 |
+
--profiling.profile_freq PROFILING.PROFILE_FREQ
|
| 147 |
+
How often to collect profiler traces, in iterations
|
| 148 |
+
--profiling.enable_memory_snapshot
|
| 149 |
+
Whether to dump memory snapshot
|
| 150 |
+
--profiling.save_memory_snapshot_folder PROFILING.SAVE_MEMORY_SNAPSHOT_FOLDER
|
| 151 |
+
Memeory snapshot files location
|
| 152 |
+
--optimizer.name OPTIMIZER.NAME
|
| 153 |
+
Optimizer to use
|
| 154 |
+
--optimizer.eps OPTIMIZER.EPS
|
| 155 |
+
Epsilon value for the optimizer.
|
| 156 |
+
--optimizer.fused Whether the fused implementation(CUDA only) is used.
|
| 157 |
+
--optimizer.scheduler {wsd,cosine,linear}
|
| 158 |
+
Scheduler to use. Currently supported: wsd, cosine,
|
| 159 |
+
and linear.
|
| 160 |
+
--optimizer.lr OPTIMIZER.LR
|
| 161 |
+
Learning rate to use
|
| 162 |
+
--optimizer.min_lr_ratio OPTIMIZER.MIN_LR_RATIO
|
| 163 |
+
Min lr ratio for lr scheduler
|
| 164 |
+
--optimizer.early_step_in_backward
|
| 165 |
+
Whether to apply optimizer in the backward. Caution,
|
| 166 |
+
optimizer_in_backward is not compatible with gradients
|
| 167 |
+
clipping, users should not call
|
| 168 |
+
register_post_accumulate_grad_hook after the optimizer
|
| 169 |
+
is built.
|
| 170 |
+
--training.batch_size TRAINING.BATCH_SIZE
|
| 171 |
+
Batch size
|
| 172 |
+
--training.seq_len TRAINING.SEQ_LEN
|
| 173 |
+
Sequence length
|
| 174 |
+
--training.context_len TRAINING.CONTEXT_LEN
|
| 175 |
+
Max length allowed for each sequence
|
| 176 |
+
--training.varlen Whether to take sequences of variable length as input
|
| 177 |
+
--training.warmup_steps TRAINING.WARMUP_STEPS
|
| 178 |
+
Steps for lr scheduler warmup, normally 1/5 of
|
| 179 |
+
--training.steps
|
| 180 |
+
--training.gradient_accumulation_steps TRAINING.GRADIENT_ACCUMULATION_STEPS
|
| 181 |
+
Number of steps to accumulate gradients before
|
| 182 |
+
updating parameters
|
| 183 |
+
--training.steps TRAINING.STEPS
|
| 184 |
+
How many train steps to run
|
| 185 |
+
--training.max_norm TRAINING.MAX_NORM
|
| 186 |
+
Max norm for gradient clipping
|
| 187 |
+
--training.skip_nan_inf
|
| 188 |
+
Skip batch updates when NaN or INF gradients are
|
| 189 |
+
encountered during training
|
| 190 |
+
--training.dataset TRAINING.DATASET
|
| 191 |
+
Dataset to use, with comma separated values
|
| 192 |
+
--training.dataset_name TRAINING.DATASET_NAME
|
| 193 |
+
The name of the dataset config, with comma separated
|
| 194 |
+
values if provided
|
| 195 |
+
--training.dataset_split TRAINING.DATASET_SPLIT
|
| 196 |
+
Dataset split to use, with comma separated values if
|
| 197 |
+
provided
|
| 198 |
+
--training.data_dir TRAINING.DATA_DIR
|
| 199 |
+
Data dirs to use, with comma separated values if
|
| 200 |
+
provided
|
| 201 |
+
--training.data_files TRAINING.DATA_FILES
|
| 202 |
+
Data files to use, with comma separated values if
|
| 203 |
+
provided
|
| 204 |
+
--training.data_probs TRAINING.DATA_PROBS
|
| 205 |
+
Data sampling probabilities, with comma separated
|
| 206 |
+
values if provided
|
| 207 |
+
--training.streaming Whether to load dataset in streaming mode, used for
|
| 208 |
+
huge dataset
|
| 209 |
+
--training.num_workers TRAINING.NUM_WORKERS
|
| 210 |
+
Number of subprocesses to use for data loading. 0
|
| 211 |
+
means that the data will be loaded in the main
|
| 212 |
+
process.
|
| 213 |
+
--training.prefetch_factor TRAINING.PREFETCH_FACTOR
|
| 214 |
+
Number of batches loaded in advance by each worker.2
|
| 215 |
+
means there will be a total of 2 * num_workers batches
|
| 216 |
+
prefetched across all workers.
|
| 217 |
+
--training.data_parallel_replicate_degree TRAINING.DATA_PARALLEL_REPLICATE_DEGREE
|
| 218 |
+
The `data_parallel_replicate_degree` argument
|
| 219 |
+
specifies the degree of data parallelism for weight
|
| 220 |
+
replication. When this value is greater than 1,
|
| 221 |
+
weights will be replicated across
|
| 222 |
+
`data_parallel_replicate_degree` ranks. If
|
| 223 |
+
`data_parallel_shard_degree` is also greater than 1,
|
| 224 |
+
the parallelism method used is HSDP (Hybrid Sharded
|
| 225 |
+
Data Parallelism). Otherwise, the parallelism method
|
| 226 |
+
used is DDP (Distributed Data Parallelism). 1 means
|
| 227 |
+
disabled.
|
| 228 |
+
--training.data_parallel_shard_degree TRAINING.DATA_PARALLEL_SHARD_DEGREE
|
| 229 |
+
The `data_parallel_shard_degree` argument specifies
|
| 230 |
+
the degree of data parallelism for weight sharding.
|
| 231 |
+
When this value is greater than 1, weights will be
|
| 232 |
+
sharded across `data_parallel_shard_degree` ranks. If
|
| 233 |
+
`data_parallel_replicate_degree` is also greater than
|
| 234 |
+
1, the parallelism method used is HSDP (Hybrid Sharded
|
| 235 |
+
Data Parallelism). Otherwise, the parallelism method
|
| 236 |
+
used is FSDP (Fully Sharded Data Parallelism). -1
|
| 237 |
+
means leftover ranks will be used (After
|
| 238 |
+
DP_REPLICATE/SP/PP). Note that only
|
| 239 |
+
`data_parallel_shard_degree` can be negative. 1 means
|
| 240 |
+
disabled.
|
| 241 |
+
--training.enable_cpu_offload
|
| 242 |
+
Whether to apply CPU offloading of parameters,
|
| 243 |
+
gradients, and optimizer states in FSDP
|
| 244 |
+
--training.tensor_parallel_degree TRAINING.TENSOR_PARALLEL_DEGREE
|
| 245 |
+
Tensor Parallelism degree. 1 means disabled.
|
| 246 |
+
--training.disable_loss_parallel
|
| 247 |
+
Whether to apply loss parallel when sequence parallel
|
| 248 |
+
is enabled
|
| 249 |
+
--training.mixed_precision_param {bfloat16,float32}
|
| 250 |
+
torch dtype to use for parameters when applying mixed
|
| 251 |
+
precision via FSDP. This feature only takes effect
|
| 252 |
+
when data_parallel_shard_degree > 1
|
| 253 |
+
--training.mixed_precision_reduce {float32}
|
| 254 |
+
torch dtype to use for reductions when applying mixed
|
| 255 |
+
precision via FSDP. This feature only takes effect
|
| 256 |
+
when data_parallel_shard_degree > 1
|
| 257 |
+
--training.compile Whether to compile the model
|
| 258 |
+
--training.gc_freq TRAINING.GC_FREQ
|
| 259 |
+
Python garbage control scheduling interval, in steps
|
| 260 |
+
--training.seed TRAINING.SEED
|
| 261 |
+
Choose the base RNG seed used for training
|
| 262 |
+
--training.deterministic
|
| 263 |
+
Use deterministic algorithms wherever possible, may be
|
| 264 |
+
slower
|
| 265 |
+
--metrics.log_freq METRICS.LOG_FREQ
|
| 266 |
+
How often to log metrics to TensorBoard, in iterations
|
| 267 |
+
--metrics.enable_tensorboard
|
| 268 |
+
Whether to log metrics to TensorBoard
|
| 269 |
+
--metrics.disable_color_printing
|
| 270 |
+
Whether to disable color printing in logs
|
| 271 |
+
--metrics.save_tb_folder METRICS.SAVE_TB_FOLDER
|
| 272 |
+
Folder to dump TensorBoard states
|
| 273 |
+
--metrics.rank_0_only
|
| 274 |
+
Whether to save TensorBoard metrics only for rank 0 or
|
| 275 |
+
for all ranks. When pipeline_parallel_degree is > 1,
|
| 276 |
+
this option uses the 0th rank of the last stage
|
| 277 |
+
pipeline group, which is the only stage that computes
|
| 278 |
+
loss metrics.
|
| 279 |
+
--metrics.enable_wandb
|
| 280 |
+
Whether to log metrics to Weights & Biases
|
| 281 |
+
--experimental.enable_async_tensor_parallel
|
| 282 |
+
Whether to apply async tensor parallel (currently only
|
| 283 |
+
effective when compile is enabled)
|
| 284 |
+
--experimental.pipeline_parallel_degree EXPERIMENTAL.PIPELINE_PARALLEL_DEGREE
|
| 285 |
+
Pipeline Parallelism degree, or number of ranks. 1
|
| 286 |
+
means disabled. If using looped schedules, this still
|
| 287 |
+
specifies the number of physical ranks, not the number
|
| 288 |
+
of stages. Stages per rank are inferred from split
|
| 289 |
+
points degree, and schedule.
|
| 290 |
+
--experimental.pipeline_parallel_split_points EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS [EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS ...]
|
| 291 |
+
Specify comma-separated names of modules to use as the
|
| 292 |
+
beginning of a split point. e.g. "layers.0,layers.2"
|
| 293 |
+
will cause the model to be split into 3 stages, the
|
| 294 |
+
first containing all the layers up to layers.0, the
|
| 295 |
+
second containing layers.0 and up to layers.2, the
|
| 296 |
+
third containing layers.2 and all the remaining
|
| 297 |
+
layers. Note: fully-automated splitting may be enabled
|
| 298 |
+
in the future, but currently the split points must be
|
| 299 |
+
specified manually.
|
| 300 |
+
--experimental.pipeline_parallel_schedule EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE
|
| 301 |
+
Specify the Pipeline Parallel schedule to use. The
|
| 302 |
+
supported schedules are: https://github.com/pytorch/py
|
| 303 |
+
torch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/to
|
| 304 |
+
rch/distributed/pipelining/schedules.py#L2161. The
|
| 305 |
+
schedule must be compatible with the split points and
|
| 306 |
+
stages_per_rank. Looped schedules (e.g.
|
| 307 |
+
Interleaved1F1B) require specifying
|
| 308 |
+
pipeline_parallel_degree = number of ranks, and
|
| 309 |
+
split_points = number of stages - 1
|
| 310 |
+
--experimental.pipeline_parallel_schedule_csv EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE_CSV
|
| 311 |
+
Specify the path to the pipeline parallel schedule csv
|
| 312 |
+
file to use. The pipeline_parallel_schedule argument
|
| 313 |
+
must be either PipelineScheduleSingle,
|
| 314 |
+
PipelineScheduleMulti, or _PipelineScheduleRuntime.
|
| 315 |
+
--experimental.pipeline_parallel_microbatches EXPERIMENTAL.PIPELINE_PARALLEL_MICROBATCHES
|
| 316 |
+
How many microbatches to split the global training
|
| 317 |
+
batch into when using pipeline parallelism. The global
|
| 318 |
+
training batch size must be evenly divisible by the
|
| 319 |
+
number of microbatches. The default value will be the
|
| 320 |
+
number of pipeline stages, if unspecified.
|
| 321 |
+
--experimental.enable_compiled_autograd
|
| 322 |
+
Enable CompiledAutograd to compile the backward.
|
| 323 |
+
--experimental.context_parallel_degree EXPERIMENTAL.CONTEXT_PARALLEL_DEGREE
|
| 324 |
+
Context parallelism degree. 1 means disabled.
|
| 325 |
+
--experimental.context_parallel_rotate_method EXPERIMENTAL.CONTEXT_PARALLEL_ROTATE_METHOD
|
| 326 |
+
The collective to use in context parallel SDPA for kv
|
| 327 |
+
shards exchange. 'allgather' means to all-gather all
|
| 328 |
+
kv shards on ranks after the first sub-SDPA
|
| 329 |
+
computation, 'alltoall' means to all-to-all shuffle
|
| 330 |
+
the kv shards. The default value is 'allgather'.
|
| 331 |
+
--checkpoint.enable_checkpoint
|
| 332 |
+
Whether to enable checkpoint
|
| 333 |
+
--checkpoint.folder CHECKPOINT.FOLDER
|
| 334 |
+
The folder to store the checkpoints. When
|
| 335 |
+
enable_checkpoint is set to true, checkpoints will be
|
| 336 |
+
in {--job.dump_folder}/{--checkpoint.folder}.
|
| 337 |
+
--checkpoint.interval_type CHECKPOINT.INTERVAL_TYPE
|
| 338 |
+
Checkpointing interval unit of measurement ['step',
|
| 339 |
+
'seconds']
|
| 340 |
+
--checkpoint.interval CHECKPOINT.INTERVAL
|
| 341 |
+
Checkpointing interval, in steps or seconds depending
|
| 342 |
+
on --checkpoint.interval_type
|
| 343 |
+
--checkpoint.model_weights_only
|
| 344 |
+
When model_weights_only=True, only model weights will
|
| 345 |
+
be saved at the end of training. With this,
|
| 346 |
+
checkpoints can be loaded using `torch.load(...,
|
| 347 |
+
weights_only=True)` after conversion. When
|
| 348 |
+
model_weights_only=False, the full checkpoint will be
|
| 349 |
+
saved. A full checkpoint includes model, optimizer and
|
| 350 |
+
train_state, which can be used to resume training. The
|
| 351 |
+
default value is false.
|
| 352 |
+
--checkpoint.export_dtype {float16,bfloat16,float32}
|
| 353 |
+
Converts to the specified precision when training
|
| 354 |
+
completes and model_weights_only=true. Currently
|
| 355 |
+
supports float32, float16, and bfloat16. The default
|
| 356 |
+
value is float32.
|
| 357 |
+
--checkpoint.create_seed_checkpoint
|
| 358 |
+
Initializes the full model without applying
|
| 359 |
+
parallelisms, and then saves it as a seed checkpoint.
|
| 360 |
+
Note: requires user to call train.py without
|
| 361 |
+
specifying any parallelisms, e.g. NGPU=1. Could be
|
| 362 |
+
implemented as a separate script, but this way shares
|
| 363 |
+
more code.
|
| 364 |
+
--checkpoint.async_mode CHECKPOINT.ASYNC_MODE
|
| 365 |
+
Which async checkpoint mode to use. Currently there
|
| 366 |
+
are 3 different modes. 1. "disabled": synchronized
|
| 367 |
+
checkpointing will be used. 2. "async":
|
| 368 |
+
torch.distributed.checkpoint.async_save will be used.
|
| 369 |
+
1. "async_with_pinned_mem": this option utilizes a
|
| 370 |
+
dedicated pinned memory space and creates a separate
|
| 371 |
+
process for faster GPU->CPU transfer performance and
|
| 372 |
+
eliminating GIL contention. The cost is increased CPU
|
| 373 |
+
memory usage. If insufficient CPU memory is available,
|
| 374 |
+
performance may degrade due to memory paging. For most
|
| 375 |
+
users, "async" should suffice as the performance
|
| 376 |
+
overhead is typically small (on the order of tens of
|
| 377 |
+
seconds) compared to checkpointing frequency. This
|
| 378 |
+
mode can be employed to pursue near-zero checkpointing
|
| 379 |
+
times (e.g., < 1 second) given appropriate hardware
|
| 380 |
+
support such as ample CPU memory and fast PCIe.
|
| 381 |
+
"disabled" is the default mode.
|
| 382 |
+
--checkpoint.keep_latest_k CHECKPOINT.KEEP_LATEST_K
|
| 383 |
+
Keeps only the latest k checkpoints, and purging older
|
| 384 |
+
ones. If 0, keep all checkpoints. 0 is the default
|
| 385 |
+
value.
|
| 386 |
+
--checkpoint.load_step CHECKPOINT.LOAD_STEP
|
| 387 |
+
Load the checkpoint at the specified step. If -1, load
|
| 388 |
+
the latest checkpoint.
|
| 389 |
+
--float8.enable_float8_linear
|
| 390 |
+
If true, swaps `torch.nn.Linear` with `Float8Linear`.
|
| 391 |
+
This feature requires you to install 'torchao' which
|
| 392 |
+
can be found here: https://github.com/pytorch/ao
|
| 393 |
+
--float8.enable_fsdp_float8_all_gather
|
| 394 |
+
Whether enable float8 all-gather in FSDP
|
| 395 |
+
--float8.precompute_float8_dynamic_scale_for_fsdp
|
| 396 |
+
Whether precompute float8 scales dynamically for FSDP
|
| 397 |
+
--float8.scaling_type_input {dynamic,delayed}
|
| 398 |
+
float8 scaling for input, dynamic (default) or delayed
|
| 399 |
+
--float8.scaling_type_weight FLOAT8.SCALING_TYPE_WEIGHT
|
| 400 |
+
float8 scaling for input, dynamic (default) or delayed
|
| 401 |
+
--float8.scaling_type_grad_output FLOAT8.SCALING_TYPE_GRAD_OUTPUT
|
| 402 |
+
float8 scaling for input, dynamic (default) or delayed
|
| 403 |
+
--comm.init_timeout_seconds COMM.INIT_TIMEOUT_SECONDS
|
| 404 |
+
Timeout for communication operations, during
|
| 405 |
+
initialization and first train step.
|
| 406 |
+
--comm.train_timeout_seconds COMM.TRAIN_TIMEOUT_SECONDS
|
| 407 |
+
Timeout for communication operations after the first
|
| 408 |
+
train step -- usually a tighter bound than during
|
| 409 |
+
initialization.
|
| 410 |
+
--comm.trace_buf_size COMM.TRACE_BUF_SIZE
|
| 411 |
+
Flight recorder ring buffer size, >0 means recording
|
| 412 |
+
by default, 0 means disabled
|
| 413 |
+
--memory_estimation.enabled
|
| 414 |
+
Whether to estimate memory usage for FSDP
|
| 415 |
+
--memory_estimation.disable_fake_mode
|
| 416 |
+
Whether to estimate memory under FakeTensorMode
|
| 417 |
+
```
|
| 418 |
+
</details>
|
| 419 |
+
|
| 420 |
+
### Training with variable-length inputs
|
| 421 |
+
When you set the `--training.varlen` flag, you're enabling a more efficient training method that packs multiple documents together into a single long sequence, eliminating the need for padding.
|
| 422 |
+
This is particularly useful when your dataset contains documents of varying lengths.
|
| 423 |
+
Let's break down how `--training.seq_len` and `--training.context_len` work in this mode.
|
| 424 |
+
|
| 425 |
+
* `--training.seq_len` (Packed Sequence Length): This is the total length of the final sequence fed to the model on one device. Instead of processing one document at a time, the dataloader takes multiple documents (each split to sequences no longer than `context_len`), concatenates them end-to-end, and creates a single long sequence of length `seq_len`.
|
| 426 |
+
* `--training.context_len` (Sample Length): This parameter defines the maximum number of tokens for a single document or sample. If a document from the dataset is longer than `context_len`, it will be truncated. For example, if `--training.context_len` is set to 4,096, a document with 5,000 tokens will be cut down to its first 4,096 tokens, leaving the left tokens as another independent sequence, while a document with 3000 tokens remains unchanged.
|
| 427 |
+
|
| 428 |
+
### Training with `torch.compile`
|
| 429 |
+
|
| 430 |
+
Starting from `torch 2.0`, `torch.compile` has been introduced as a new feature to seamlessly accelerate training processes.
|
| 431 |
+
In `flame`, one can simply enable `torch.compile` by adding `--training.compile` flag to your training script.
|
| 432 |
+
|
| 433 |
+
However, `fla` has integrated numerous fused kernels for acceleration, which may potentially conflict with `torch.compile`.
|
| 434 |
+
We are actively working on resolving these issues to make compilation transparent to users.
|
| 435 |
+
In the meantime, please ensure you are using the latest dependencies.
|
| 436 |
+
|
| 437 |
+
Specifically, **we recommend using `torch>=2.6` and `triton>=3.0`**.
|
| 438 |
+
|
| 439 |
+
### Training with multiple datasets
|
| 440 |
+
|
| 441 |
+
If you wish to train a model with all-round capabilities (e.g., code, math, and multilingual ability), it's necessary to train on multiple datasets.
|
| 442 |
+
`flame` allows training with multiple datasets easily.
|
| 443 |
+
For example, you can specify the following arguments to train on 6 datasets with different proportions:
|
| 444 |
+
|
| 445 |
+
```sh
|
| 446 |
+
--training.dataset HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus \
|
| 447 |
+
--training.data_probs 0.6,0.15,0.15,0.014,0.058,0.028 \
|
| 448 |
+
```
|
| 449 |
+
|
| 450 |
+
### ~Finalizing training~
|
| 451 |
+
|
| 452 |
+
> [!NOTE]
|
| 453 |
+
> We have done this conversion automatically in the training script since our latest updates.
|
| 454 |
+
|
| 455 |
+
Once training is complete, you may want to convert the distributed checkpoints (DCPs) into the 🤗 format for broader use.
|
| 456 |
+
To facilitate this, we provide a straightforward conversion script:
|
| 457 |
+
|
| 458 |
+
```sh
|
| 459 |
+
python -m flame.utils.convert_dcp_to_hf --path <path_to_model> --step <step> --config <path_to_config> --tokenizer <path_to_tokenizer>
|
| 460 |
+
```
|
| 461 |
+
After this, your model will be in the 🤗 format, ready to be shared or deployed.
|
| 462 |
+
You can then easily publish your model using the `huggingface_hub` for wider accessibility.
|
| 463 |
+
|
| 464 |
+
### Continual training
|
| 465 |
+
|
| 466 |
+
If you wish to build upon a strong pre-trained model (in 🤗 format) and continue training, we also offer a script to convert the 🤗 format model back into DCP format.
|
| 467 |
+
This allows you to seamlessly resume training with `flame`.
|
| 468 |
+
```sh
|
| 469 |
+
python -m flame.utils.convert_hf_to_dcp --model <path_to_hf> --checkpoint <path_to_dcp/checkpoint/step-0>
|
| 470 |
+
```
|
| 471 |
+
Here, `<path_to_dcp>` is the directory where your distributed checkpoints will be stored.
|
| 472 |
+
The checkpoint is intentionally saved at `<step-0>` within the checkpoint folder to ensure it is loadable by `flame` during the initial training step, similar to how a seed checkpoint is handled.
|
| 473 |
+
|
| 474 |
+
Once the conversion is complete, you can proceed with training using `flame` as usual, continuing from where the pretrained model left off.
|
| 475 |
+
|
| 476 |
+
## Multi-node training
|
| 477 |
+
|
| 478 |
+
If you have access to multi-node GPUs, consider leveraging them for optimal performance.
|
| 479 |
+
This process is straightforward and well-documented in the PyTorch [docs](https://pytorch.org/docs/stable/elastic/run.html).
|
| 480 |
+
|
| 481 |
+
To set up multi-node training:
|
| 482 |
+
* Set the environment variables `MASTER_ADDR=<ip>` and `MASTER_PORT=<port>` before running the training script across all nodes.
|
| 483 |
+
* If you're using a job scheduler like Slurm, it will handle these variables for you.
|
| 484 |
+
|
| 485 |
+
`torchtitan` provides a [Slurm script](https://github.com/pytorch/torchtitan/blob/main/multinode_trainer.slurm) for multi-node training, which you can use as a reference or starting point.
|
| 486 |
+
|
| 487 |
+
## Custom models
|
| 488 |
+
|
| 489 |
+
`flame` supports custom model architectures through seamless integration with the Hugging Face `transformers` library. To add your own model:
|
| 490 |
+
|
| 491 |
+
1. Create a new model directory under `custom_models/` (see `custom_models/sba` for a complete example)
|
| 492 |
+
2. Implement your model classes and configuration:
|
| 493 |
+
- Define a config class inheriting from `PretrainedConfig` (see `custom_models/sba/config_sba.py` for an example)
|
| 494 |
+
- Create model classes inheriting from `PreTrainedModel` (see `custom_models/sba/modeling_sba.py` for an example)
|
| 495 |
+
3. Register your models in `__init__.py`:
|
| 496 |
+
- Import your model classes and config classes
|
| 497 |
+
- Register your models with the `AutoModelForCausalLM`, `AutoModel` and `AutoConfig` classes (see `custom_models/sba/__init__.py` for an example)
|
| 498 |
+
4. Create a config file for your custom model, just need to specify the `model_type` to the one you just named for your custom model (example: `configs/sba_340m.json`).
|
| 499 |
+
5. Training is extremely simple, you can just use the `flame.train.py` script to train your custom model.
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
## Citation
|
| 508 |
+
|
| 509 |
+
If you find `flame` helpful for your work, please consider citing it.
|
| 510 |
+
|
| 511 |
+
```bib
|
| 512 |
+
@software{yang2025flame,
|
| 513 |
+
title = {Flame: Flash Language Modeling Made Easy},
|
| 514 |
+
author = {Zhang, Yu and Yang, Songlin},
|
| 515 |
+
url = {https://github.com/fla-org/flame},
|
| 516 |
+
month = jan,
|
| 517 |
+
year = {2025}
|
| 518 |
+
}
|
| 519 |
+
```
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/.metadata
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5ee5637775addc1d2413c1529f93a45a5dc2c691a3b8ecce1e341c87104b2bc0
|
| 3 |
+
size 614499
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__0_0.distcp
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4cd4780ad243c7a639e1ceb0bd3f4e48305f1cf4ca8b2045f6580a3cf475e2bd
|
| 3 |
+
size 1585679607
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__1_0.distcp
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:df67dd042a83a7025e0d30c9447fb256e079396dccc8faac0dc73ebbcf2e9210
|
| 3 |
+
size 1363429789
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__2_0.distcp
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dcfd24e2a2131d1bb8509b2d568c21922f4cd283f95ccd9cb45ddb645a18aea0
|
| 3 |
+
size 1336900942
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__3_0.distcp
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6887615ea2e51f81acbe1a993188ae987f6ab668004fa8171504baf81ae2ec8c
|
| 3 |
+
size 1362992323
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__4_0.distcp
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e6cc1aa227872b9bed2a7b2cafe45ada0daf9925c44570056dd5d0c52581ef33
|
| 3 |
+
size 1336843981
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__5_0.distcp
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5a6e283af458d723a1cb23b68d98c3b69f780b139bcb323b557529b4fff75a2d
|
| 3 |
+
size 1326186430
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__6_0.distcp
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0db3d563ce31caa4da8d8fbd51b84e25c1aa0a7987b602b3951e727f8cfd3396
|
| 3 |
+
size 1363165665
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/checkpoint/step-30720/__7_0.distcp
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:45adc5132fc26012a29c89a1c97de44cc45c3be9ed589a0ab71f67a72216b8ea
|
| 3 |
+
size 1331199796
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/config.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"TransformerForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"bos_token_id": 1,
|
| 6 |
+
"dtype": "float32",
|
| 7 |
+
"elementwise_affine": true,
|
| 8 |
+
"eos_token_id": 2,
|
| 9 |
+
"fuse_cross_entropy": true,
|
| 10 |
+
"fuse_linear_cross_entropy": false,
|
| 11 |
+
"fuse_norm": true,
|
| 12 |
+
"fuse_swiglu": true,
|
| 13 |
+
"hidden_act": "swish",
|
| 14 |
+
"hidden_ratio": 4,
|
| 15 |
+
"hidden_size": 2048,
|
| 16 |
+
"initializer_range": 0.02,
|
| 17 |
+
"intermediate_size": null,
|
| 18 |
+
"max_position_embeddings": 8192,
|
| 19 |
+
"model_type": "transformer",
|
| 20 |
+
"norm_eps": 1e-06,
|
| 21 |
+
"num_heads": 32,
|
| 22 |
+
"num_hidden_layers": 24,
|
| 23 |
+
"num_kv_heads": null,
|
| 24 |
+
"pad_token_id": 2,
|
| 25 |
+
"qk_norm": false,
|
| 26 |
+
"qkv_bias": false,
|
| 27 |
+
"rope_theta": 10000.0,
|
| 28 |
+
"tie_word_embeddings": false,
|
| 29 |
+
"transformers_version": "4.57.6",
|
| 30 |
+
"use_cache": true,
|
| 31 |
+
"use_l2warp": false,
|
| 32 |
+
"vocab_size": 32000,
|
| 33 |
+
"window_size": null
|
| 34 |
+
}
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/blt_transformer_1000hash.json
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "blt",
|
| 3 |
+
"vocab_size": 260,
|
| 4 |
+
"max_position_embeddings": 4096,
|
| 5 |
+
"initializer_range": 0.02,
|
| 6 |
+
"tie_word_embeddings": false,
|
| 7 |
+
"patch_in_forward": true,
|
| 8 |
+
"patch_size": 4,
|
| 9 |
+
"patching_mode": "entropy",
|
| 10 |
+
"patching_threshold": 1.335442066192627,
|
| 11 |
+
"patching_batch_size": 1,
|
| 12 |
+
"max_patch_length": null,
|
| 13 |
+
"patching_device": "cuda",
|
| 14 |
+
"realtime_patching": true,
|
| 15 |
+
"patching_threshold_add": null,
|
| 16 |
+
"monotonicity": false,
|
| 17 |
+
"cross_attn_k": 2,
|
| 18 |
+
"encoder_hash_byte_group_size": [3, 4, 5, 6, 7, 8],
|
| 19 |
+
"encoder_hash_byte_group_vocab": 1000,
|
| 20 |
+
"encoder_hash_byte_group_nb_functions": 1,
|
| 21 |
+
"patcher_config": {
|
| 22 |
+
"model_type": "blt_patcher",
|
| 23 |
+
"vocab_size": 260,
|
| 24 |
+
"hidden_size": 512,
|
| 25 |
+
"num_hidden_layers": 7,
|
| 26 |
+
"num_attention_heads": 8,
|
| 27 |
+
"num_key_value_heads": 8,
|
| 28 |
+
"max_position_embeddings": 8192,
|
| 29 |
+
"rms_norm_eps": 1e-5,
|
| 30 |
+
"dropout": 0.0,
|
| 31 |
+
"intermediate_size": 1365,
|
| 32 |
+
"hidden_act": "silu",
|
| 33 |
+
"initializer_range": 0.02,
|
| 34 |
+
"rope_parameters": {"rope_type": "default",
|
| 35 |
+
"rope_theta": 500000
|
| 36 |
+
}
|
| 37 |
+
},
|
| 38 |
+
"encoder_config": {
|
| 39 |
+
"model_type": "blt_local_encoder",
|
| 40 |
+
"vocab_size": 260,
|
| 41 |
+
"hidden_size": 512,
|
| 42 |
+
"hidden_size_global": 1024,
|
| 43 |
+
"num_hidden_layers": 1,
|
| 44 |
+
"num_attention_heads": 8,
|
| 45 |
+
"num_key_value_heads": 8,
|
| 46 |
+
"head_dim": 64,
|
| 47 |
+
"intermediate_size": 1365,
|
| 48 |
+
"rms_norm_eps": 1e-5,
|
| 49 |
+
"dropout": 0.0,
|
| 50 |
+
"max_position_embeddings": 24576,
|
| 51 |
+
"cross_attn_all_layers": false,
|
| 52 |
+
"cross_attn_k": 2,
|
| 53 |
+
"hidden_act": "silu",
|
| 54 |
+
"initializer_range": 0.02,
|
| 55 |
+
"rope_parameters": {"rope_type": "default",
|
| 56 |
+
"rope_theta": 500000
|
| 57 |
+
}
|
| 58 |
+
},
|
| 59 |
+
"decoder_config": {
|
| 60 |
+
"model_type": "blt_local_decoder",
|
| 61 |
+
"vocab_size": 260,
|
| 62 |
+
"hidden_size": 512,
|
| 63 |
+
"hidden_size_global": 1024,
|
| 64 |
+
"num_hidden_layers": 9,
|
| 65 |
+
"num_attention_heads": 8,
|
| 66 |
+
"num_key_value_heads": 8,
|
| 67 |
+
"head_dim": 64,
|
| 68 |
+
"intermediate_size": 1365,
|
| 69 |
+
"rms_norm_eps": 1e-5,
|
| 70 |
+
"dropout": 0.0,
|
| 71 |
+
"max_position_embeddings": 24576,
|
| 72 |
+
"cross_attn_all_layers": true,
|
| 73 |
+
"cross_attn_k": 2,
|
| 74 |
+
"hidden_act": "silu",
|
| 75 |
+
"initializer_range": 0.02,
|
| 76 |
+
"rope_parameters": {"rope_type": "default",
|
| 77 |
+
"rope_theta": 500000
|
| 78 |
+
}
|
| 79 |
+
},
|
| 80 |
+
"global_config": {
|
| 81 |
+
"model_type": "blt_global_transformer",
|
| 82 |
+
"hidden_size": 1024,
|
| 83 |
+
"num_hidden_layers": 25,
|
| 84 |
+
"num_attention_heads": 8,
|
| 85 |
+
"num_key_value_heads": 8,
|
| 86 |
+
"head_dim": 128,
|
| 87 |
+
"intermediate_size": 2731,
|
| 88 |
+
"rms_norm_eps": 1e-5,
|
| 89 |
+
"dropout": 0.0,
|
| 90 |
+
"max_position_embeddings": 4096,
|
| 91 |
+
"hidden_act": "silu",
|
| 92 |
+
"initializer_range": 0.02,
|
| 93 |
+
"rope_parameters": {"rope_type": "default",
|
| 94 |
+
"rope_theta": 500000
|
| 95 |
+
},
|
| 96 |
+
"encoder_cross_output_size": null
|
| 97 |
+
}
|
| 98 |
+
}
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/blt_transformer_1_5B.json
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "blt",
|
| 3 |
+
"vocab_size": 260,
|
| 4 |
+
"max_position_embeddings": 4096,
|
| 5 |
+
"initializer_range": 0.02,
|
| 6 |
+
"tie_word_embeddings": false,
|
| 7 |
+
"patch_in_forward": true,
|
| 8 |
+
"patch_size": 4,
|
| 9 |
+
"patching_mode": "entropy",
|
| 10 |
+
"patching_threshold": 1.335442066192627,
|
| 11 |
+
"patching_batch_size": 1,
|
| 12 |
+
"max_patch_length": null,
|
| 13 |
+
"patching_device": "cuda",
|
| 14 |
+
"realtime_patching": true,
|
| 15 |
+
"patching_threshold_add": null,
|
| 16 |
+
"monotonicity": false,
|
| 17 |
+
"cross_attn_k": 2,
|
| 18 |
+
"encoder_hash_byte_group_size": [3, 4, 5, 6, 7, 8],
|
| 19 |
+
"encoder_hash_byte_group_vocab": 500,
|
| 20 |
+
"encoder_hash_byte_group_nb_functions": 1,
|
| 21 |
+
"patcher_config": {
|
| 22 |
+
"model_type": "blt_patcher",
|
| 23 |
+
"vocab_size": 260,
|
| 24 |
+
"hidden_size": 768,
|
| 25 |
+
"num_hidden_layers": 7,
|
| 26 |
+
"num_attention_heads": 12,
|
| 27 |
+
"num_key_value_heads": 12,
|
| 28 |
+
"max_position_embeddings": 8192,
|
| 29 |
+
"rms_norm_eps": 1e-5,
|
| 30 |
+
"dropout": 0.0,
|
| 31 |
+
"intermediate_size": 2048,
|
| 32 |
+
"hidden_act": "silu",
|
| 33 |
+
"initializer_range": 0.02,
|
| 34 |
+
"rope_parameters": {"rope_type": "default",
|
| 35 |
+
"rope_theta": 500000
|
| 36 |
+
}
|
| 37 |
+
},
|
| 38 |
+
"encoder_config": {
|
| 39 |
+
"model_type": "blt_local_encoder",
|
| 40 |
+
"vocab_size": 260,
|
| 41 |
+
"hidden_size": 1024,
|
| 42 |
+
"hidden_size_global": 2048,
|
| 43 |
+
"num_hidden_layers": 1,
|
| 44 |
+
"num_attention_heads": 16,
|
| 45 |
+
"num_key_value_heads": 16,
|
| 46 |
+
"head_dim": 64,
|
| 47 |
+
"intermediate_size": 2816,
|
| 48 |
+
"rms_norm_eps": 1e-5,
|
| 49 |
+
"dropout": 0.0,
|
| 50 |
+
"max_position_embeddings": 24576,
|
| 51 |
+
"cross_attn_all_layers": false,
|
| 52 |
+
"cross_attn_k": 2,
|
| 53 |
+
"hidden_act": "silu",
|
| 54 |
+
"initializer_range": 0.02,
|
| 55 |
+
"rope_parameters": {"rope_type": "default",
|
| 56 |
+
"rope_theta": 500000
|
| 57 |
+
}
|
| 58 |
+
},
|
| 59 |
+
|
| 60 |
+
"decoder_config": {
|
| 61 |
+
"model_type": "blt_local_decoder",
|
| 62 |
+
"vocab_size": 260,
|
| 63 |
+
"hidden_size": 1024,
|
| 64 |
+
"hidden_size_global": 2048,
|
| 65 |
+
"num_hidden_layers": 9,
|
| 66 |
+
"num_attention_heads": 16,
|
| 67 |
+
"num_key_value_heads": 16,
|
| 68 |
+
"head_dim": 64,
|
| 69 |
+
"intermediate_size": 2816,
|
| 70 |
+
"rms_norm_eps": 1e-5,
|
| 71 |
+
"dropout": 0.0,
|
| 72 |
+
"max_position_embeddings": 24576,
|
| 73 |
+
"cross_attn_all_layers": true,
|
| 74 |
+
"cross_attn_k": 2,
|
| 75 |
+
"hidden_act": "silu",
|
| 76 |
+
"initializer_range": 0.02,
|
| 77 |
+
"rope_parameters": {"rope_type": "default",
|
| 78 |
+
"rope_theta": 500000
|
| 79 |
+
}
|
| 80 |
+
},
|
| 81 |
+
"global_config": {
|
| 82 |
+
"model_type": "blt_global_transformer",
|
| 83 |
+
"hidden_size": 2048,
|
| 84 |
+
"num_hidden_layers": 25,
|
| 85 |
+
"num_attention_heads": 16,
|
| 86 |
+
"num_key_value_heads": 16,
|
| 87 |
+
"head_dim": 128,
|
| 88 |
+
"intermediate_size": 5632,
|
| 89 |
+
"rms_norm_eps": 1e-5,
|
| 90 |
+
"dropout": 0.0,
|
| 91 |
+
"max_position_embeddings": 4096,
|
| 92 |
+
"hidden_act": "silu",
|
| 93 |
+
"initializer_range": 0.02,
|
| 94 |
+
"rope_parameters": {"rope_type": "default",
|
| 95 |
+
"rope_theta": 500000
|
| 96 |
+
},
|
| 97 |
+
"encoder_cross_output_size": null
|
| 98 |
+
}
|
| 99 |
+
}
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/blt_transformer_380M.json
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "blt",
|
| 3 |
+
"vocab_size": 260,
|
| 4 |
+
"max_position_embeddings": 4096,
|
| 5 |
+
"initializer_range": 0.02,
|
| 6 |
+
"tie_word_embeddings": false,
|
| 7 |
+
"patch_in_forward": true,
|
| 8 |
+
"patch_size": 4,
|
| 9 |
+
"patching_mode": "entropy",
|
| 10 |
+
"patching_threshold": 1.335442066192627,
|
| 11 |
+
"patching_batch_size": 1,
|
| 12 |
+
"max_patch_length": null,
|
| 13 |
+
"patching_device": "cuda",
|
| 14 |
+
"realtime_patching": true,
|
| 15 |
+
"patching_threshold_add": null,
|
| 16 |
+
"monotonicity": false,
|
| 17 |
+
"cross_attn_k": 2,
|
| 18 |
+
"encoder_hash_byte_group_size": [3, 4, 5, 6, 7, 8],
|
| 19 |
+
"encoder_hash_byte_group_vocab": 500,
|
| 20 |
+
"encoder_hash_byte_group_nb_functions": 1,
|
| 21 |
+
"patcher_config": {
|
| 22 |
+
"model_type": "blt_patcher",
|
| 23 |
+
"vocab_size": 260,
|
| 24 |
+
"hidden_size": 512,
|
| 25 |
+
"num_hidden_layers": 7,
|
| 26 |
+
"num_attention_heads": 8,
|
| 27 |
+
"num_key_value_heads": 8,
|
| 28 |
+
"max_position_embeddings": 8192,
|
| 29 |
+
"rms_norm_eps": 1e-5,
|
| 30 |
+
"dropout": 0.0,
|
| 31 |
+
"intermediate_size": 1365,
|
| 32 |
+
"hidden_act": "silu",
|
| 33 |
+
"initializer_range": 0.02,
|
| 34 |
+
"rope_parameters": {"rope_type": "default",
|
| 35 |
+
"rope_theta": 500000
|
| 36 |
+
}
|
| 37 |
+
},
|
| 38 |
+
"encoder_config": {
|
| 39 |
+
"model_type": "blt_local_encoder",
|
| 40 |
+
"vocab_size": 260,
|
| 41 |
+
"hidden_size": 512,
|
| 42 |
+
"hidden_size_global": 1024,
|
| 43 |
+
"num_hidden_layers": 1,
|
| 44 |
+
"num_attention_heads": 8,
|
| 45 |
+
"num_key_value_heads": 8,
|
| 46 |
+
"head_dim": 64,
|
| 47 |
+
"intermediate_size": 1365,
|
| 48 |
+
"rms_norm_eps": 1e-5,
|
| 49 |
+
"dropout": 0.0,
|
| 50 |
+
"max_position_embeddings": 24576,
|
| 51 |
+
"cross_attn_all_layers": false,
|
| 52 |
+
"cross_attn_k": 2,
|
| 53 |
+
"hidden_act": "silu",
|
| 54 |
+
"initializer_range": 0.02,
|
| 55 |
+
"rope_parameters": {"rope_type": "default",
|
| 56 |
+
"rope_theta": 500000
|
| 57 |
+
}
|
| 58 |
+
},
|
| 59 |
+
"decoder_config": {
|
| 60 |
+
"model_type": "blt_local_decoder",
|
| 61 |
+
"vocab_size": 260,
|
| 62 |
+
"hidden_size": 512,
|
| 63 |
+
"hidden_size_global": 1024,
|
| 64 |
+
"num_hidden_layers": 9,
|
| 65 |
+
"num_attention_heads": 8,
|
| 66 |
+
"num_key_value_heads": 8,
|
| 67 |
+
"head_dim": 64,
|
| 68 |
+
"intermediate_size": 1365,
|
| 69 |
+
"rms_norm_eps": 1e-5,
|
| 70 |
+
"dropout": 0.0,
|
| 71 |
+
"max_position_embeddings": 24576,
|
| 72 |
+
"cross_attn_all_layers": true,
|
| 73 |
+
"cross_attn_k": 2,
|
| 74 |
+
"hidden_act": "silu",
|
| 75 |
+
"initializer_range": 0.02,
|
| 76 |
+
"rope_parameters": {"rope_type": "default",
|
| 77 |
+
"rope_theta": 500000
|
| 78 |
+
}
|
| 79 |
+
},
|
| 80 |
+
"global_config": {
|
| 81 |
+
"model_type": "blt_global_transformer",
|
| 82 |
+
"hidden_size": 1024,
|
| 83 |
+
"num_hidden_layers": 25,
|
| 84 |
+
"num_attention_heads": 8,
|
| 85 |
+
"num_key_value_heads": 8,
|
| 86 |
+
"head_dim": 128,
|
| 87 |
+
"intermediate_size": 2731,
|
| 88 |
+
"rms_norm_eps": 1e-5,
|
| 89 |
+
"dropout": 0.0,
|
| 90 |
+
"max_position_embeddings": 4096,
|
| 91 |
+
"hidden_act": "silu",
|
| 92 |
+
"initializer_range": 0.02,
|
| 93 |
+
"rope_parameters": {"rope_type": "default",
|
| 94 |
+
"rope_theta": 500000
|
| 95 |
+
},
|
| 96 |
+
"encoder_cross_output_size": null
|
| 97 |
+
}
|
| 98 |
+
}
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/delta_net_1B.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attn": null,
|
| 3 |
+
"attn_mode": "chunk",
|
| 4 |
+
"bos_token_id": 1,
|
| 5 |
+
"conv_size": 4,
|
| 6 |
+
"eos_token_id": 2,
|
| 7 |
+
"expand_k": 1,
|
| 8 |
+
"expand_v": 1,
|
| 9 |
+
"fuse_cross_entropy": true,
|
| 10 |
+
"fuse_norm": true,
|
| 11 |
+
"hidden_act": "swish",
|
| 12 |
+
"hidden_ratio": 4,
|
| 13 |
+
"hidden_size": 2048,
|
| 14 |
+
"initializer_range": 0.02,
|
| 15 |
+
"intermediate_size": null,
|
| 16 |
+
"model_type": "delta_net",
|
| 17 |
+
"norm_eps": 1e-06,
|
| 18 |
+
"num_heads": 16,
|
| 19 |
+
"num_hidden_layers": 24,
|
| 20 |
+
"pad_token_id": 2,
|
| 21 |
+
"qk_activation": "silu",
|
| 22 |
+
"qk_norm": "l2",
|
| 23 |
+
"tie_word_embeddings": false,
|
| 24 |
+
"use_beta": true,
|
| 25 |
+
"use_cache": true,
|
| 26 |
+
"use_gate": false,
|
| 27 |
+
"use_output_norm": true,
|
| 28 |
+
"use_short_conv": true
|
| 29 |
+
}
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/delta_net_340M.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attn_mode": "chunk",
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"conv_size": 4,
|
| 5 |
+
"eos_token_id": 2,
|
| 6 |
+
"expand_k": 1,
|
| 7 |
+
"expand_v": 1,
|
| 8 |
+
"fuse_cross_entropy": true,
|
| 9 |
+
"hidden_act": "swish",
|
| 10 |
+
"hidden_ratio": 4,
|
| 11 |
+
"hidden_size": 1024,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": null,
|
| 14 |
+
"model_type": "delta_net",
|
| 15 |
+
"norm_eps": 1e-06,
|
| 16 |
+
"num_heads": 8,
|
| 17 |
+
"num_hidden_layers": 24,
|
| 18 |
+
"qk_activation": "silu",
|
| 19 |
+
"qk_norm": "l2",
|
| 20 |
+
"tie_word_embeddings": false,
|
| 21 |
+
"use_beta": true,
|
| 22 |
+
"use_cache": true,
|
| 23 |
+
"use_gate": false,
|
| 24 |
+
"use_output_norm": true,
|
| 25 |
+
"use_short_conv": true
|
| 26 |
+
}
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/gated_deltanet_1B.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attn_mode": "chunk",
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"conv_size": 4,
|
| 5 |
+
"eos_token_id": 2,
|
| 6 |
+
"expand_v": 2,
|
| 7 |
+
"fuse_cross_entropy": true,
|
| 8 |
+
"head_dim": 256,
|
| 9 |
+
"hidden_act": "swish",
|
| 10 |
+
"hidden_ratio": 4,
|
| 11 |
+
"hidden_size": 2048,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": null,
|
| 14 |
+
"model_type": "gated_deltanet",
|
| 15 |
+
"norm_eps": 1e-06,
|
| 16 |
+
"num_heads": 6,
|
| 17 |
+
"num_hidden_layers": 21,
|
| 18 |
+
"tie_word_embeddings": false,
|
| 19 |
+
"use_cache": true,
|
| 20 |
+
"use_gate": true,
|
| 21 |
+
"use_short_conv": true
|
| 22 |
+
}
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/gated_deltanet_340M.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attn_mode": "chunk",
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"conv_size": 4,
|
| 5 |
+
"eos_token_id": 2,
|
| 6 |
+
"expand_v": 2,
|
| 7 |
+
"fuse_cross_entropy": true,
|
| 8 |
+
"head_dim": 256,
|
| 9 |
+
"hidden_act": "swish",
|
| 10 |
+
"hidden_ratio": 4,
|
| 11 |
+
"hidden_size": 1024,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": null,
|
| 14 |
+
"model_type": "gated_deltanet",
|
| 15 |
+
"norm_eps": 1e-06,
|
| 16 |
+
"num_heads": 6,
|
| 17 |
+
"num_hidden_layers": 21,
|
| 18 |
+
"tie_word_embeddings": false,
|
| 19 |
+
"use_cache": true,
|
| 20 |
+
"use_gate": true,
|
| 21 |
+
"use_short_conv": true
|
| 22 |
+
}
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/gated_deltanet_h_340M.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "gated_deltanet",
|
| 3 |
+
"attn_mode": "chunk",
|
| 4 |
+
"hidden_size": 1024,
|
| 5 |
+
"num_hidden_layers": 21,
|
| 6 |
+
"head_dim": 256,
|
| 7 |
+
"num_heads": 6,
|
| 8 |
+
"expand_v": 2,
|
| 9 |
+
"hidden_ratio": 4,
|
| 10 |
+
"use_gate": true,
|
| 11 |
+
"use_short_conv": true,
|
| 12 |
+
"conv_size": 4,
|
| 13 |
+
"vocab_size": 32000,
|
| 14 |
+
"hidden_act": "swish",
|
| 15 |
+
"norm_eps": 1e-06,
|
| 16 |
+
"bos_token_id": 1,
|
| 17 |
+
"eos_token_id": 2,
|
| 18 |
+
"fuse_cross_entropy": true,
|
| 19 |
+
"initializer_range": 0.02,
|
| 20 |
+
"attn": {
|
| 21 |
+
"layers": [3, 7, 11, 15, 19],
|
| 22 |
+
"num_heads": 8,
|
| 23 |
+
"num_kv_heads": 1,
|
| 24 |
+
"window_size": 2048,
|
| 25 |
+
"rope_theta": 100000.0,
|
| 26 |
+
"qkv_bias": false
|
| 27 |
+
}
|
| 28 |
+
}
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/gla_1B.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attn_mode": "chunk",
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"clamp_min": null,
|
| 5 |
+
"eos_token_id": 2,
|
| 6 |
+
"expand_k": 0.5,
|
| 7 |
+
"expand_v": 1,
|
| 8 |
+
"fuse_cross_entropy": true,
|
| 9 |
+
"fuse_norm": true,
|
| 10 |
+
"hidden_act": "swish",
|
| 11 |
+
"hidden_ratio": 4,
|
| 12 |
+
"hidden_size": 2048,
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"intermediate_size": null,
|
| 15 |
+
"model_type": "gla",
|
| 16 |
+
"num_heads": 4,
|
| 17 |
+
"num_hidden_layers": 24,
|
| 18 |
+
"norm_eps": 1e-06,
|
| 19 |
+
"tie_word_embeddings": false,
|
| 20 |
+
"use_cache": true,
|
| 21 |
+
"use_gk": true,
|
| 22 |
+
"use_gv": false,
|
| 23 |
+
"vocab_size": 32000
|
| 24 |
+
}
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/gla_340M.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attn_mode": "chunk",
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"clamp_min": null,
|
| 5 |
+
"eos_token_id": 2,
|
| 6 |
+
"expand_k": 0.5,
|
| 7 |
+
"expand_v": 1,
|
| 8 |
+
"fuse_cross_entropy": true,
|
| 9 |
+
"fuse_norm": true,
|
| 10 |
+
"hidden_act": "swish",
|
| 11 |
+
"hidden_ratio": 4,
|
| 12 |
+
"hidden_size": 1024,
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"intermediate_size": null,
|
| 15 |
+
"model_type": "gla",
|
| 16 |
+
"num_heads": 4,
|
| 17 |
+
"num_hidden_layers": 24,
|
| 18 |
+
"norm_eps": 1e-06,
|
| 19 |
+
"tie_word_embeddings": false,
|
| 20 |
+
"use_cache": true,
|
| 21 |
+
"use_gk": true,
|
| 22 |
+
"use_gv": false,
|
| 23 |
+
"vocab_size": 32000
|
| 24 |
+
}
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/gla_7B.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attn": null,
|
| 3 |
+
"attn_mode": "chunk",
|
| 4 |
+
"bos_token_id": 1,
|
| 5 |
+
"eos_token_id": 2,
|
| 6 |
+
"expand_k": 0.5,
|
| 7 |
+
"expand_v": 1,
|
| 8 |
+
"fuse_cross_entropy": true,
|
| 9 |
+
"fuse_norm": true,
|
| 10 |
+
"hidden_act": "swish",
|
| 11 |
+
"hidden_ratio": 4,
|
| 12 |
+
"hidden_size": 4096,
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"intermediate_size": 11008,
|
| 15 |
+
"model_type": "gla",
|
| 16 |
+
"norm_eps": 1e-06,
|
| 17 |
+
"num_heads": 16,
|
| 18 |
+
"num_hidden_layers": 32,
|
| 19 |
+
"tie_word_embeddings": false,
|
| 20 |
+
"use_cache": true,
|
| 21 |
+
"use_gk": true,
|
| 22 |
+
"use_gv": false,
|
| 23 |
+
"use_output_gate": true,
|
| 24 |
+
"use_short_conv": false
|
| 25 |
+
}
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/gsa_340M.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 1,
|
| 3 |
+
"conv_size": 4,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"expand_k": 1,
|
| 6 |
+
"expand_v": 1,
|
| 7 |
+
"elementwise_affine": false,
|
| 8 |
+
"feature_map": "swish",
|
| 9 |
+
"fuse_cross_entropy": true,
|
| 10 |
+
"fuse_norm": true,
|
| 11 |
+
"gate_logit_normalizer": 4,
|
| 12 |
+
"hidden_act": "swish",
|
| 13 |
+
"hidden_ratio": 4,
|
| 14 |
+
"hidden_size": 1024,
|
| 15 |
+
"initializer_range": 0.02,
|
| 16 |
+
"intermediate_size": null,
|
| 17 |
+
"model_type": "gsa",
|
| 18 |
+
"num_heads": 4,
|
| 19 |
+
"num_hidden_layers": 24,
|
| 20 |
+
"num_slots": 64,
|
| 21 |
+
"norm_eps": 1e-06,
|
| 22 |
+
"share_conv_kernel": true,
|
| 23 |
+
"tie_word_embeddings": false,
|
| 24 |
+
"use_cache": true,
|
| 25 |
+
"use_norm": true,
|
| 26 |
+
"use_output_gate": true,
|
| 27 |
+
"use_rope": false,
|
| 28 |
+
"use_short_conv": false
|
| 29 |
+
}
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/mergenet_340M.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "mergenet",
|
| 3 |
+
"vocab_size": 260,
|
| 4 |
+
"hidden_size": 1024,
|
| 5 |
+
"num_local_layers": 6,
|
| 6 |
+
"local_depth": 4,
|
| 7 |
+
"num_latent_layers": 12,
|
| 8 |
+
"num_heads": 16,
|
| 9 |
+
"num_kv_heads": 16,
|
| 10 |
+
"intermediate_size": 4096,
|
| 11 |
+
"hidden_act": "swish",
|
| 12 |
+
"max_position_embeddings": 8192,
|
| 13 |
+
"lambda_local": 4.0,
|
| 14 |
+
"dtem_window_size": 8,
|
| 15 |
+
"dtem_t": 1,
|
| 16 |
+
"dtem_feat_dim": null,
|
| 17 |
+
"use_softkmax": false,
|
| 18 |
+
"grid_bias_gamma": 1.0,
|
| 19 |
+
"W_infer": null,
|
| 20 |
+
"qkv_bias": true,
|
| 21 |
+
"qk_norm": false,
|
| 22 |
+
"rope_theta": 10000.0,
|
| 23 |
+
"norm_eps": 1e-6,
|
| 24 |
+
"initializer_range": 0.02,
|
| 25 |
+
"use_cache": true,
|
| 26 |
+
"pad_token_id": 0,
|
| 27 |
+
"bos_token_id": 1,
|
| 28 |
+
"eos_token_id": 2,
|
| 29 |
+
"tie_word_embeddings": false,
|
| 30 |
+
"phase": "phase2",
|
| 31 |
+
"drop_rate": 0.0,
|
| 32 |
+
"attn_drop_rate": 0.0,
|
| 33 |
+
"drop_path_rate": 0.1
|
| 34 |
+
}
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/mergenet_64M.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "mergenet",
|
| 3 |
+
"vocab_size": 32000,
|
| 4 |
+
"hidden_size": 512,
|
| 5 |
+
"num_local_layers": 4,
|
| 6 |
+
"local_depth": 4,
|
| 7 |
+
"num_latent_layers": 8,
|
| 8 |
+
"num_heads": 8,
|
| 9 |
+
"num_kv_heads": 8,
|
| 10 |
+
"intermediate_size": 2048,
|
| 11 |
+
"hidden_act": "swish",
|
| 12 |
+
"max_position_embeddings": 4096,
|
| 13 |
+
"lambda_local": 4.0,
|
| 14 |
+
"dtem_window_size": 8,
|
| 15 |
+
"dtem_t": 1,
|
| 16 |
+
"dtem_feat_dim": null,
|
| 17 |
+
"use_softkmax": false,
|
| 18 |
+
"grid_bias_gamma": 1.0,
|
| 19 |
+
"W_infer": null,
|
| 20 |
+
"qkv_bias": true,
|
| 21 |
+
"qk_norm": false,
|
| 22 |
+
"rope_theta": 10000.0,
|
| 23 |
+
"norm_eps": 1e-6,
|
| 24 |
+
"initializer_range": 0.02,
|
| 25 |
+
"use_cache": true,
|
| 26 |
+
"pad_token_id": 0,
|
| 27 |
+
"bos_token_id": 1,
|
| 28 |
+
"eos_token_id": 2,
|
| 29 |
+
"tie_word_embeddings": false,
|
| 30 |
+
"phase": "phase2",
|
| 31 |
+
"drop_rate": 0.0,
|
| 32 |
+
"attn_drop_rate": 0.0,
|
| 33 |
+
"drop_path_rate": 0.1
|
| 34 |
+
}
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/qwen3_next_1B.json
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "qwen3_next",
|
| 3 |
+
"vocab_size": 151936,
|
| 4 |
+
"hidden_size": 2048,
|
| 5 |
+
"intermediate_size": 5632,
|
| 6 |
+
"num_hidden_layers": 48,
|
| 7 |
+
"num_attention_heads": 16,
|
| 8 |
+
"num_key_value_heads": 2,
|
| 9 |
+
"head_dim": 256,
|
| 10 |
+
"hidden_act": "silu",
|
| 11 |
+
"max_position_embeddings": 32768,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"rms_norm_eps": 1e-6,
|
| 14 |
+
"use_cache": true,
|
| 15 |
+
"tie_word_embeddings": false,
|
| 16 |
+
"attention_bias": false,
|
| 17 |
+
"attention_dropout": 0.0,
|
| 18 |
+
"rope_parameters": {
|
| 19 |
+
"rope_type": "default",
|
| 20 |
+
"factor": 1.0
|
| 21 |
+
},
|
| 22 |
+
"partial_rotary_factor": 0.25,
|
| 23 |
+
"layer_types": [
|
| 24 |
+
"linear_attention",
|
| 25 |
+
"linear_attention",
|
| 26 |
+
"linear_attention",
|
| 27 |
+
"full_attention"
|
| 28 |
+
],
|
| 29 |
+
"linear_conv_kernel_dim": 4,
|
| 30 |
+
"linear_key_head_dim": 128,
|
| 31 |
+
"linear_value_head_dim": 128,
|
| 32 |
+
"linear_num_key_heads": 16,
|
| 33 |
+
"linear_num_value_heads": 32,
|
| 34 |
+
"decoder_sparse_step": 1,
|
| 35 |
+
"moe_intermediate_size": 512,
|
| 36 |
+
"shared_expert_intermediate_size": 512,
|
| 37 |
+
"num_experts_per_tok": 10,
|
| 38 |
+
"num_experts": 512,
|
| 39 |
+
"norm_topk_prob": true,
|
| 40 |
+
"output_router_logits": false,
|
| 41 |
+
"router_aux_loss_coef": 0.001,
|
| 42 |
+
"mlp_only_layers": []
|
| 43 |
+
}
|
| 44 |
+
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/qwen3_next_350M.json
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "qwen3_next",
|
| 3 |
+
"vocab_size": 32000,
|
| 4 |
+
"hidden_size": 2048,
|
| 5 |
+
"intermediate_size": 5632,
|
| 6 |
+
"num_hidden_layers": 26,
|
| 7 |
+
"num_attention_heads": 16,
|
| 8 |
+
"num_key_value_heads": 2,
|
| 9 |
+
"head_dim": 256,
|
| 10 |
+
"hidden_act": "silu",
|
| 11 |
+
"max_position_embeddings": 32768,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"rms_norm_eps": 1e-6,
|
| 14 |
+
"use_cache": true,
|
| 15 |
+
"tie_word_embeddings": false,
|
| 16 |
+
"attention_bias": false,
|
| 17 |
+
"attention_dropout": 0.0,
|
| 18 |
+
"rope_parameters": {
|
| 19 |
+
"rope_type": "default",
|
| 20 |
+
"factor": 1.0
|
| 21 |
+
},
|
| 22 |
+
"partial_rotary_factor": 0.25,
|
| 23 |
+
"layer_types": [
|
| 24 |
+
"linear_attention",
|
| 25 |
+
"linear_attention",
|
| 26 |
+
"linear_attention",
|
| 27 |
+
"full_attention"
|
| 28 |
+
],
|
| 29 |
+
"linear_conv_kernel_dim": 4,
|
| 30 |
+
"linear_key_head_dim": 128,
|
| 31 |
+
"linear_value_head_dim": 128,
|
| 32 |
+
"linear_num_key_heads": 16,
|
| 33 |
+
"linear_num_value_heads": 32,
|
| 34 |
+
"decoder_sparse_step": 1,
|
| 35 |
+
"moe_intermediate_size": 512,
|
| 36 |
+
"shared_expert_intermediate_size": 512,
|
| 37 |
+
"num_experts_per_tok": 10,
|
| 38 |
+
"num_experts": 512,
|
| 39 |
+
"norm_topk_prob": true,
|
| 40 |
+
"output_router_logits": false,
|
| 41 |
+
"router_aux_loss_coef": 0.001,
|
| 42 |
+
"mlp_only_layers": []
|
| 43 |
+
}
|
| 44 |
+
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/transformer_1B.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 1,
|
| 3 |
+
"elementwise_affine": true,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"fuse_cross_entropy": true,
|
| 6 |
+
"fuse_norm": true,
|
| 7 |
+
"fuse_swiglu": true,
|
| 8 |
+
"hidden_act": "swish",
|
| 9 |
+
"hidden_ratio": 4,
|
| 10 |
+
"hidden_size": 2048,
|
| 11 |
+
"initializer_range": 0.02,
|
| 12 |
+
"intermediate_size": null,
|
| 13 |
+
"max_position_embeddings": 8192,
|
| 14 |
+
"model_type": "transformer",
|
| 15 |
+
"norm_eps": 1e-06,
|
| 16 |
+
"num_heads": 32,
|
| 17 |
+
"num_hidden_layers": 24,
|
| 18 |
+
"num_kv_heads": null,
|
| 19 |
+
"pad_token_id": 2,
|
| 20 |
+
"rope_theta": 10000.0,
|
| 21 |
+
"tie_word_embeddings": false
|
| 22 |
+
}
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/transformer_340M.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attention_bias": false,
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"fuse_cross_entropy": true,
|
| 6 |
+
"fuse_norm": true,
|
| 7 |
+
"hidden_act": "swish",
|
| 8 |
+
"hidden_size": 1024,
|
| 9 |
+
"initializer_range": 0.02,
|
| 10 |
+
"max_position_embeddings": 8192,
|
| 11 |
+
"model_type": "transformer",
|
| 12 |
+
"num_heads": 16,
|
| 13 |
+
"num_hidden_layers": 24,
|
| 14 |
+
"norm_eps": 1e-06,
|
| 15 |
+
"tie_word_embeddings": false,
|
| 16 |
+
"use_cache": true,
|
| 17 |
+
"vocab_size": 32000
|
| 18 |
+
}
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/configs/transformer_7B.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attention_bias": false,
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"fuse_cross_entropy": true,
|
| 6 |
+
"fuse_norm": true,
|
| 7 |
+
"hidden_act": "swish",
|
| 8 |
+
"hidden_ratio": 4,
|
| 9 |
+
"hidden_size": 4096,
|
| 10 |
+
"initializer_range": 0.02,
|
| 11 |
+
"intermediate_size": 14336,
|
| 12 |
+
"model_type": "transformer",
|
| 13 |
+
"norm_eps": 1e-06,
|
| 14 |
+
"num_heads": 32,
|
| 15 |
+
"num_hidden_layers": 32,
|
| 16 |
+
"num_kv_heads": 8,
|
| 17 |
+
"rope_theta": 10000.0,
|
| 18 |
+
"tie_word_embeddings": false,
|
| 19 |
+
"use_cache": true,
|
| 20 |
+
"window_size": null
|
| 21 |
+
}
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__version__ = "0.1.0"
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (207 Bytes). View file
|
|
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (244 Bytes). View file
|
|
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (238 Bytes). View file
|
|
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/config_manager.cpython-310.pyc
ADDED
|
Binary file (29.6 kB). View file
|
|
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/config_manager.cpython-311.pyc
ADDED
|
Binary file (41.5 kB). View file
|
|
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/data.cpython-310.pyc
ADDED
|
Binary file (21.7 kB). View file
|
|
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/data.cpython-311.pyc
ADDED
|
Binary file (41.6 kB). View file
|
|
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/train.cpython-310.pyc
ADDED
|
Binary file (19.6 kB). View file
|
|
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/train.cpython-311.pyc
ADDED
|
Binary file (41.2 kB). View file
|
|
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/__pycache__/train.cpython-313.pyc
ADDED
|
Binary file (39.7 kB). View file
|
|
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/c4_test.py
ADDED
|
@@ -0,0 +1,603 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
import json
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
from loguru import logger
|
| 16 |
+
|
| 17 |
+
import transformers
|
| 18 |
+
|
| 19 |
+
transformers.logging.set_verbosity_error()
|
| 20 |
+
|
| 21 |
+
import wandb
|
| 22 |
+
|
| 23 |
+
from utils.argparse import parse_args
|
| 24 |
+
from utils.setup import getting_svd_cnt, set_seed, setup_model, saving_model_weight, load_model_weight
|
| 25 |
+
from utils.optimizer_factory import setup_optimization
|
| 26 |
+
from utils.eval import evaluate_model
|
| 27 |
+
from utils.dataloader import setup_dataset
|
| 28 |
+
from utils.modeling_llama import LlamaForCausalLM
|
| 29 |
+
from utils.fake_quantization import QLinear
|
| 30 |
+
from utils.quantization import QScaleLinear
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def main(args):
|
| 34 |
+
import torch
|
| 35 |
+
############ Setup random seed ############
|
| 36 |
+
set_seed(args)
|
| 37 |
+
|
| 38 |
+
############ Setup DDP environment ############
|
| 39 |
+
assert "LOCAL_RANK" in os.environ, "torchrun should set LOCAL_RANK"
|
| 40 |
+
global_rank = int(os.environ["RANK"])
|
| 41 |
+
local_rank = int(os.environ["LOCAL_RANK"])
|
| 42 |
+
world_size = int(os.environ["WORLD_SIZE"])
|
| 43 |
+
torch.cuda.set_device(local_rank)
|
| 44 |
+
|
| 45 |
+
logger.info(f"Global rank {global_rank}, local rank {local_rank}, device: {torch.cuda.current_device()}")
|
| 46 |
+
dist.init_process_group(backend="nccl", rank=global_rank, world_size=world_size)
|
| 47 |
+
|
| 48 |
+
logger.info("Process group initialized")
|
| 49 |
+
device = f"cuda:{local_rank}"
|
| 50 |
+
|
| 51 |
+
if global_rank != 0:
|
| 52 |
+
logger.remove() # turn off logger
|
| 53 |
+
|
| 54 |
+
logger.info(f"Using dist with rank {global_rank} (only rank 0 will log)")
|
| 55 |
+
logger.info("*" * 40)
|
| 56 |
+
logger.info(f"Starting training with the arguments")
|
| 57 |
+
for k, v in vars(args).items():
|
| 58 |
+
logger.info(f"{k:30} {v}")
|
| 59 |
+
logger.info("*" * 40)
|
| 60 |
+
|
| 61 |
+
############ Initialize wandb without config (it is passed later) ############
|
| 62 |
+
if (not args.unset_wandb) and global_rank == 0:
|
| 63 |
+
if args.entity is None:
|
| 64 |
+
os.environ['WANDB_MODE'] = 'offline'
|
| 65 |
+
# Set wandb directory for offline mode
|
| 66 |
+
wandb_dir = getattr(args, 'wandb_dir', None) if getattr(args, 'wandb_dir', None) is not None else args.save_dir
|
| 67 |
+
if getattr(args, 'wandb_dir', None) is not None:
|
| 68 |
+
logger.info(f"Wandb directory set to: {wandb_dir}")
|
| 69 |
+
wandb.init(project=args.project, name=args.name, entity=args.entity, dir=wandb_dir)
|
| 70 |
+
|
| 71 |
+
############ Setup training data ############
|
| 72 |
+
if args.total_batch_size is not None:
|
| 73 |
+
if args.gradient_accumulation is None:
|
| 74 |
+
assert args.total_batch_size % world_size == 0, "total_batch_size must be divisible by world_size"
|
| 75 |
+
args.gradient_accumulation = args.total_batch_size // (args.batch_size * world_size)
|
| 76 |
+
assert args.gradient_accumulation > 0, "gradient_accumulation must be greater than 0"
|
| 77 |
+
|
| 78 |
+
assert (
|
| 79 |
+
args.gradient_accumulation * args.batch_size * world_size == args.total_batch_size
|
| 80 |
+
), "gradient_accumulation * batch_size * world_size must be equal to total_batch_size"
|
| 81 |
+
|
| 82 |
+
dataloader, tokenizer = setup_dataset(args, global_rank, world_size)
|
| 83 |
+
|
| 84 |
+
############ Initialize model ############
|
| 85 |
+
model_config, model = setup_model(args)
|
| 86 |
+
# Ensure model has generation_config (fix for transformers version compatibility)
|
| 87 |
+
if model.generation_config is None:
|
| 88 |
+
from transformers import GenerationConfig
|
| 89 |
+
model.generation_config = GenerationConfig()
|
| 90 |
+
model.generation_config.pad_token_id = tokenizer.pad_token_id
|
| 91 |
+
|
| 92 |
+
############ Resuming from checkpoints ############
|
| 93 |
+
global_step = 0
|
| 94 |
+
update_step = 0
|
| 95 |
+
beginning_step = 0
|
| 96 |
+
tokens_seen = 0
|
| 97 |
+
tokens_seen_before = 0
|
| 98 |
+
|
| 99 |
+
# identifying checkpointing
|
| 100 |
+
if args.continue_from is not None and os.path.exists(args.continue_from):
|
| 101 |
+
# searching the latest checkpoints
|
| 102 |
+
checkpoint_path_list = os.listdir(args.continue_from)
|
| 103 |
+
checkpoint_path_list = [int(x.split("_")[-1]) for x in checkpoint_path_list if x.startswith("model_")]
|
| 104 |
+
if len(checkpoint_path_list) > 0:
|
| 105 |
+
logger.info("Find Checkpoints", checkpoint_path_list)
|
| 106 |
+
beginning_step = max(checkpoint_path_list)
|
| 107 |
+
if args.resume_step is not None:
|
| 108 |
+
beginning_step = args.resume_step
|
| 109 |
+
args.continue_from = os.path.join(args.continue_from, f"model_{beginning_step}")
|
| 110 |
+
logger.info("Continue from", args.continue_from)
|
| 111 |
+
else:
|
| 112 |
+
logger.warning(f"Did not find any checkpoints in {args.continue_from}")
|
| 113 |
+
args.continue_from = None
|
| 114 |
+
|
| 115 |
+
# resuming from checkpointing
|
| 116 |
+
if args.continue_from is not None:
|
| 117 |
+
logger.info("*" * 40)
|
| 118 |
+
logger.info(f"Loading model from {args.continue_from}")
|
| 119 |
+
checkpoint_path = os.path.join(args.continue_from, "pytorch_model.bin")
|
| 120 |
+
if os.path.exists(checkpoint_path):
|
| 121 |
+
load_model_weight(model, checkpoint_path, args)
|
| 122 |
+
logger.info(f"Model successfully loaded (strict=False policy)")
|
| 123 |
+
else:
|
| 124 |
+
# Try safetensors format
|
| 125 |
+
checkpoint_path = os.path.join(args.continue_from, "model.safetensors")
|
| 126 |
+
if os.path.exists(checkpoint_path):
|
| 127 |
+
from safetensors import safe_open
|
| 128 |
+
tensors = {}
|
| 129 |
+
with safe_open(checkpoint_path, framework="pt", device=0) as f:
|
| 130 |
+
for k in f.keys():
|
| 131 |
+
tensors[k] = f.get_tensor(k)
|
| 132 |
+
print(k, tensors[k].shape)
|
| 133 |
+
ret = model.load_state_dict(tensors, strict=False)
|
| 134 |
+
logger.info(f"Model successfully loaded from safetensors (strict=False policy)", ret)
|
| 135 |
+
else:
|
| 136 |
+
logger.warning(f"No model checkpoint found in {args.continue_from}")
|
| 137 |
+
|
| 138 |
+
if os.path.exists(os.path.join(args.continue_from, "training_state.json")):
|
| 139 |
+
logger.info(
|
| 140 |
+
f"Loading training state like global_step, update_step, and tokens_seen from {args.continue_from}"
|
| 141 |
+
)
|
| 142 |
+
with open(os.path.join(args.continue_from, "training_state.json")) as f:
|
| 143 |
+
_old_state = json.load(f)
|
| 144 |
+
global_step = _old_state["global_step"]
|
| 145 |
+
update_step = _old_state["update_step"]
|
| 146 |
+
tokens_seen = _old_state["tokens_seen"]
|
| 147 |
+
tokens_seen_before = _old_state["tokens_seen_before"]
|
| 148 |
+
logger.info(f"global_step : {global_step}")
|
| 149 |
+
logger.info(f"update_step : {update_step}")
|
| 150 |
+
logger.info(f"tokens_seen : {tokens_seen}")
|
| 151 |
+
logger.info(f"tokens_seen_before: {tokens_seen_before}")
|
| 152 |
+
logger.info(f"Will train for {args.num_training_steps - update_step} update steps")
|
| 153 |
+
else:
|
| 154 |
+
logger.warning(f"Did not find training state in {args.continue_from}, global step will start from zero")
|
| 155 |
+
logger.info("*" * 40)
|
| 156 |
+
|
| 157 |
+
############ Setup model ############
|
| 158 |
+
if args.dtype in ["bf16", "bfloat16"]:
|
| 159 |
+
model = model.to(dtype=torch.bfloat16)
|
| 160 |
+
model = model.to(device=device)
|
| 161 |
+
|
| 162 |
+
for _, module in model.named_modules():
|
| 163 |
+
if isinstance(module, QScaleLinear):
|
| 164 |
+
weight_device = module.weight.device
|
| 165 |
+
module.weight.scales = module.weight.scales.to(device=weight_device)
|
| 166 |
+
module.weight.zeros = module.weight.zeros.to(device=weight_device)
|
| 167 |
+
|
| 168 |
+
n_total_params = sum(p.numel() for p in model.parameters())
|
| 169 |
+
trainable_params = [p for p in model.parameters() if p.requires_grad]
|
| 170 |
+
trainable_params_int8 = [p for p in model.parameters() if hasattr(p, "group_size")]
|
| 171 |
+
|
| 172 |
+
############ Initialize wandb ############
|
| 173 |
+
run_config = dict(vars(args))
|
| 174 |
+
run_config.update(
|
| 175 |
+
{
|
| 176 |
+
"max_lr": run_config.pop("lr"), # rename lr to max_lr to avoid conflicts with scheduler
|
| 177 |
+
"total_params_M": n_total_params / 1_000_000,
|
| 178 |
+
"dataset": "c4",
|
| 179 |
+
"model": model_config.to_dict(),
|
| 180 |
+
"world_size": world_size,
|
| 181 |
+
"device": str(device),
|
| 182 |
+
}
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
if global_rank == 0:
|
| 186 |
+
if not args.unset_wandb:
|
| 187 |
+
wandb.config.update(run_config, allow_val_change=True)
|
| 188 |
+
wandb.save(os.path.abspath(__file__), policy="now") # save current script
|
| 189 |
+
# fix tqdm visual length to 80 so that the progress bar
|
| 190 |
+
# doesn't jump around when changing from external display to laptop
|
| 191 |
+
pbar = tqdm(total=args.num_training_steps - update_step, desc="Update steps", ncols=80)
|
| 192 |
+
|
| 193 |
+
############ Initialize optimization ############
|
| 194 |
+
if "galore" in args.optimizer.lower():
|
| 195 |
+
# make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
|
| 196 |
+
lowrank_params = []
|
| 197 |
+
target_modules_list = ["attn", "mlp"]
|
| 198 |
+
for module_name, module in model.named_modules():
|
| 199 |
+
if not (isinstance(module, nn.Linear) or isinstance(module, QScaleLinear) or isinstance(module, QLinear)):
|
| 200 |
+
continue
|
| 201 |
+
if not any(target_key in module_name for target_key in target_modules_list):
|
| 202 |
+
continue
|
| 203 |
+
logger.info(f"Adding {module_name} to GaLore parameters")
|
| 204 |
+
lowrank_params.append(module.weight)
|
| 205 |
+
|
| 206 |
+
id_lowrank_params = [id(p) for p in lowrank_params]
|
| 207 |
+
# make parameters without "rank" to another group
|
| 208 |
+
regular_params = [p for p in model.parameters() if id(p) not in id_lowrank_params]
|
| 209 |
+
# then call low rank optimizer
|
| 210 |
+
param_groups = [
|
| 211 |
+
{"params": regular_params},
|
| 212 |
+
{
|
| 213 |
+
"params": lowrank_params,
|
| 214 |
+
"rank": args.rank,
|
| 215 |
+
"update_proj_gap": args.update_proj_gap,
|
| 216 |
+
"scale": args.galore_scale,
|
| 217 |
+
"proj_type": args.proj_type,
|
| 218 |
+
"quant": args.proj_quant,
|
| 219 |
+
"quant_n_bit": args.proj_bits,
|
| 220 |
+
"quant_group_size": args.proj_group_size,
|
| 221 |
+
"cos_threshold": args.cos_threshold,
|
| 222 |
+
"gamma_proj": args.gamma_proj,
|
| 223 |
+
"queue_size": args.queue_size,
|
| 224 |
+
},
|
| 225 |
+
]
|
| 226 |
+
elif "apollo" in args.optimizer.lower():
|
| 227 |
+
# make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
|
| 228 |
+
lowrank_params = []
|
| 229 |
+
target_modules_list = ["attn", "mlp"]
|
| 230 |
+
for module_name, module in model.named_modules():
|
| 231 |
+
if not (isinstance(module, nn.Linear) or isinstance(module, QScaleLinear) or isinstance(module, QLinear)):
|
| 232 |
+
continue
|
| 233 |
+
if not any(target_key in module_name for target_key in target_modules_list):
|
| 234 |
+
continue
|
| 235 |
+
logger.info(f"Adding {module_name} to APOLLO parameters")
|
| 236 |
+
lowrank_params.append(module.weight)
|
| 237 |
+
|
| 238 |
+
id_lowrank_params = [id(p) for p in lowrank_params]
|
| 239 |
+
# make parameters without "rank" to another group
|
| 240 |
+
regular_params = [p for p in model.parameters() if id(p) not in id_lowrank_params]
|
| 241 |
+
# then call low rank optimizer
|
| 242 |
+
param_groups = [
|
| 243 |
+
{"params": regular_params},
|
| 244 |
+
{
|
| 245 |
+
"params": lowrank_params,
|
| 246 |
+
"rank": args.rank,
|
| 247 |
+
"update_proj_gap": args.update_proj_gap,
|
| 248 |
+
"scale": args.apollo_scale,
|
| 249 |
+
"proj_type": args.proj_type,
|
| 250 |
+
"proj": args.proj,
|
| 251 |
+
"scale_type": args.scale_type,
|
| 252 |
+
},
|
| 253 |
+
]
|
| 254 |
+
elif "conda" in args.optimizer.lower():
|
| 255 |
+
# make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
|
| 256 |
+
lowrank_params = []
|
| 257 |
+
target_modules_list = ["attn", "mlp"]
|
| 258 |
+
for module_name, module in model.named_modules():
|
| 259 |
+
if not (isinstance(module, nn.Linear) or isinstance(module, QScaleLinear) or isinstance(module, QLinear)):
|
| 260 |
+
continue
|
| 261 |
+
if not any(target_key in module_name for target_key in target_modules_list):
|
| 262 |
+
continue
|
| 263 |
+
logger.info(f"Adding {module_name} to conda parameters")
|
| 264 |
+
lowrank_params.append(module.weight)
|
| 265 |
+
|
| 266 |
+
id_lowrank_params = [id(p) for p in lowrank_params]
|
| 267 |
+
# make parameters without "rank" to another group
|
| 268 |
+
regular_params = [p for p in model.parameters() if id(p) not in id_lowrank_params]
|
| 269 |
+
# then call low rank optimizer
|
| 270 |
+
param_groups = [
|
| 271 |
+
{"params": regular_params},
|
| 272 |
+
{
|
| 273 |
+
"params": lowrank_params,
|
| 274 |
+
"rank": args.rank,
|
| 275 |
+
"update_proj_gap": args.update_proj_gap,
|
| 276 |
+
"scale": args.apollo_scale,
|
| 277 |
+
"proj_type": args.proj_type,
|
| 278 |
+
"proj": args.proj,
|
| 279 |
+
"scale_type": args.scale_type,
|
| 280 |
+
},
|
| 281 |
+
]
|
| 282 |
+
else:
|
| 283 |
+
param_groups = None
|
| 284 |
+
id_lowrank_params = None
|
| 285 |
+
|
| 286 |
+
# print params and trainable params
|
| 287 |
+
logger.info(f"\n{model}\n")
|
| 288 |
+
logger.info(f"Total params: {sum(p.numel() for p in model.parameters()) / 1_000_000:.2f}M")
|
| 289 |
+
|
| 290 |
+
if args.simulation:
|
| 291 |
+
num_train_params = sum(p.numel() for p in trainable_params)
|
| 292 |
+
else:
|
| 293 |
+
num_train_params = sum(p.numel() for p in trainable_params) + sum(p.numel() for p in trainable_params_int8)
|
| 294 |
+
|
| 295 |
+
logger.info(f"Trainable params: {num_train_params / 1_000_000:.2f}M")
|
| 296 |
+
if "q_galore" in args.optimizer.lower():
|
| 297 |
+
logger.info(
|
| 298 |
+
f"Trainable params with Q-GaLore enabled: {sum(p.numel() for p in trainable_params_int8) / 1_000_000:.2f}M"
|
| 299 |
+
)
|
| 300 |
+
elif "galore" in args.optimizer.lower():
|
| 301 |
+
logger.info(f"Total params with GaLore enabled: {sum(p.numel() for p in lowrank_params) / 1_000_000:.2f}M")
|
| 302 |
+
elif "q_apollo" in args.optimizer.lower():
|
| 303 |
+
logger.info(
|
| 304 |
+
f"Trainable params with Q-APOLLO enabled: {sum(p.numel() for p in trainable_params_int8) / 1_000_000:.2f}M"
|
| 305 |
+
)
|
| 306 |
+
elif "apollo" in args.optimizer.lower():
|
| 307 |
+
logger.info(f"Total params with APOLLO enabled: {sum(p.numel() for p in lowrank_params) / 1_000_000:.2f}M")
|
| 308 |
+
|
| 309 |
+
logger.info(f"Saving model to {args.save_dir} every {args.save_every} update steps")
|
| 310 |
+
|
| 311 |
+
model, optimizer, scheduler, layer_wise_flag = setup_optimization(
|
| 312 |
+
args, model, trainable_params, param_groups, id_lowrank_params, model_config
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
if layer_wise_flag:
|
| 316 |
+
# will pass optimizer_dict and scheduler_dict out instead of optimizer and scheduler
|
| 317 |
+
optimizer_dict = optimizer
|
| 318 |
+
scheduler_dict = scheduler
|
| 319 |
+
|
| 320 |
+
# Bug-3 fix: wrap with DDP *before* torch.compile per PyTorch recommendation.
|
| 321 |
+
# This ensures gradient reduction hooks are correctly installed on the DDP module,
|
| 322 |
+
# and the compiled graph captures the full DDP+model forward pass.
|
| 323 |
+
# (Issue-5: optimizer.load_state_dict is called after both DDP and compile below.)
|
| 324 |
+
if not args.single_gpu:
|
| 325 |
+
model: LlamaForCausalLM = torch.nn.parallel.DistributedDataParallel(
|
| 326 |
+
model,
|
| 327 |
+
device_ids=[local_rank],
|
| 328 |
+
output_device=local_rank,
|
| 329 |
+
broadcast_buffers=False,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
# compile the model (after DDP so the compiled graph includes DDP reduction)
|
| 333 |
+
if args.compile:
|
| 334 |
+
print("Compiling the model... (takes a ~minute)")
|
| 335 |
+
unoptimized_model = model
|
| 336 |
+
|
| 337 |
+
# Configure TorchDynamo to suppress errors and fall back to eager mode
|
| 338 |
+
import torch._dynamo
|
| 339 |
+
torch._dynamo.config.suppress_errors = args.dynamo_suppress_errors
|
| 340 |
+
torch._dynamo.config.verbose = False
|
| 341 |
+
# Set cache size limit to prevent memory issues during long training
|
| 342 |
+
torch._dynamo.config.cache_size_limit = args.dynamo_cache_limit
|
| 343 |
+
|
| 344 |
+
model = torch.compile(model) # requires PyTorch 2.0
|
| 345 |
+
|
| 346 |
+
# resume optimizer
|
| 347 |
+
if args.restore_optimizer and args.continue_from is not None:
|
| 348 |
+
logger.info("Restoring optimizer and scheduler from the checkpoint")
|
| 349 |
+
_optimizer_dir = args.continue_from
|
| 350 |
+
optimizer_checkpoint = torch.load(os.path.join(_optimizer_dir, "optimizer.pt"), map_location="cpu")
|
| 351 |
+
optimizer.load_state_dict(optimizer_checkpoint["optimizer"])
|
| 352 |
+
scheduler.load_state_dict(optimizer_checkpoint["scheduler"])
|
| 353 |
+
update_step = optimizer_checkpoint["update_step"]
|
| 354 |
+
beginning_step = update_step
|
| 355 |
+
global_step = optimizer_checkpoint["global_step"]
|
| 356 |
+
logger.info(f"Optimizer and scheduler restored from {_optimizer_dir}")
|
| 357 |
+
|
| 358 |
+
# ##############################
|
| 359 |
+
# TRAINING LOOP
|
| 360 |
+
# we use iterable dataset, so we may never go through all the data
|
| 361 |
+
# ##############################
|
| 362 |
+
# global steps and others are defined above
|
| 363 |
+
pad_idx = tokenizer.pad_token_id
|
| 364 |
+
update_time = time.time()
|
| 365 |
+
local_step = 0 # when continue_from is used, local_step != global_step
|
| 366 |
+
total_svd_count = 0
|
| 367 |
+
|
| 368 |
+
dataloader_iter = iter(dataloader)
|
| 369 |
+
|
| 370 |
+
# Issue-4 fix: accumulate loss across micro-batches so logged loss is the true
|
| 371 |
+
# gradient-accumulation average, not just the last micro-batch.
|
| 372 |
+
accumulated_loss = 0.0
|
| 373 |
+
|
| 374 |
+
# Skip data if resuming from checkpoint
|
| 375 |
+
if update_step != 0:
|
| 376 |
+
skip_batches = args.gradient_accumulation * update_step
|
| 377 |
+
logger.info(f"Skipping {skip_batches} batches to resume from update step {update_step}")
|
| 378 |
+
skipped = 0
|
| 379 |
+
for _ in range(skip_batches):
|
| 380 |
+
# Issue-6 fix: handle StopIteration during skip so all ranks stay aligned
|
| 381 |
+
try:
|
| 382 |
+
next(dataloader_iter)
|
| 383 |
+
except StopIteration:
|
| 384 |
+
logger.warning(
|
| 385 |
+
f"Dataset exhausted during skip at batch {skipped}/{skip_batches}; "
|
| 386 |
+
f"restarting iterator to keep ranks aligned."
|
| 387 |
+
)
|
| 388 |
+
dataloader_iter = iter(dataloader)
|
| 389 |
+
next(dataloader_iter)
|
| 390 |
+
skipped += 1
|
| 391 |
+
logger.info(f"Skipped {skipped} batches successfully")
|
| 392 |
+
|
| 393 |
+
while update_step <= args.num_training_steps:
|
| 394 |
+
try:
|
| 395 |
+
batch = next(dataloader_iter)
|
| 396 |
+
except StopIteration:
|
| 397 |
+
logger.info(f"Dataset completed one epoch. Starting new epoch with reshuffled data.")
|
| 398 |
+
dataloader_iter = iter(dataloader)
|
| 399 |
+
batch = next(dataloader_iter)
|
| 400 |
+
|
| 401 |
+
global_step += 1
|
| 402 |
+
local_step += 1
|
| 403 |
+
|
| 404 |
+
if update_step >= args.num_training_steps:
|
| 405 |
+
logger.info(f"Reached max number of update steps ({args.num_training_steps}). Stopping training.")
|
| 406 |
+
logger.info(f"Rank {global_rank} stopping training.")
|
| 407 |
+
break
|
| 408 |
+
|
| 409 |
+
# forward & backward
|
| 410 |
+
batch = {k: v.to(device) for k, v in batch.items()}
|
| 411 |
+
labels = batch["input_ids"].clone()
|
| 412 |
+
labels[labels == pad_idx] = -100
|
| 413 |
+
tokens_seen += (batch["input_ids"] != pad_idx).sum().item() * world_size
|
| 414 |
+
|
| 415 |
+
loss = model(**batch, labels=labels).loss
|
| 416 |
+
|
| 417 |
+
scaled_loss = loss / args.gradient_accumulation
|
| 418 |
+
scaled_loss.backward()
|
| 419 |
+
accumulated_loss += loss.item() # Issue-4: accumulate before the continue
|
| 420 |
+
|
| 421 |
+
if global_step % args.gradient_accumulation != 0:
|
| 422 |
+
continue
|
| 423 |
+
|
| 424 |
+
# The below code is only executed during the update step
|
| 425 |
+
# Issue-4: compute average loss over all micro-batches in this accumulation window
|
| 426 |
+
avg_loss = accumulated_loss / args.gradient_accumulation
|
| 427 |
+
accumulated_loss = 0.0 # reset for next accumulation window
|
| 428 |
+
# add grad clipping: TODO: add gradient clipping of int8 weight
|
| 429 |
+
if args.grad_clipping != 0.0:
|
| 430 |
+
torch.nn.utils.clip_grad_norm_(trainable_params, args.grad_clipping)
|
| 431 |
+
# Periodic memory cleanup to prevent symbolic tensor issues during long training
|
| 432 |
+
if global_step % args.memory_cleanup_frequency == 0:
|
| 433 |
+
torch.cuda.empty_cache()
|
| 434 |
+
# Clear TorchDynamo cache to prevent memory accumulation
|
| 435 |
+
if args.compile:
|
| 436 |
+
import torch._dynamo
|
| 437 |
+
torch._dynamo.reset()
|
| 438 |
+
|
| 439 |
+
if global_rank == 0:
|
| 440 |
+
pbar.update(1)
|
| 441 |
+
if not layer_wise_flag: # layer-wise updation is done during backward; requires gradient_accumulation equals 1
|
| 442 |
+
optimizer.step()
|
| 443 |
+
scheduler.step()
|
| 444 |
+
optimizer.zero_grad()
|
| 445 |
+
|
| 446 |
+
update_step += 1
|
| 447 |
+
update_time = time.time() - update_time
|
| 448 |
+
|
| 449 |
+
# save checkpoint by save_every
|
| 450 |
+
if local_step > args.gradient_accumulation and update_step % args.save_every == 0 and global_rank == 0:
|
| 451 |
+
current_model_directory = f"{args.save_dir}/model_{update_step}"
|
| 452 |
+
logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}")
|
| 453 |
+
os.makedirs(args.save_dir, exist_ok=True)
|
| 454 |
+
# Bug-1 fix: unwrap DDP/compiled model for saving; works in both single-GPU and multi-GPU modes
|
| 455 |
+
unwrapped_model = model.module if hasattr(model, 'module') else model
|
| 456 |
+
unwrapped_model.save_pretrained(current_model_directory, max_shard_size="500GB", from_pt=True)
|
| 457 |
+
saving_model_weight(unwrapped_model, f"{current_model_directory}/pytorch_model.bin", args)
|
| 458 |
+
|
| 459 |
+
optimizer_checkpoint = {
|
| 460 |
+
"optimizer": optimizer.state_dict(),
|
| 461 |
+
"scheduler": scheduler.state_dict(),
|
| 462 |
+
"update_step": update_step,
|
| 463 |
+
"global_step": global_step,
|
| 464 |
+
"config": run_config,
|
| 465 |
+
"wandb": wandb.run.dir if not args.unset_wandb else None,
|
| 466 |
+
"dtype": args.dtype,
|
| 467 |
+
}
|
| 468 |
+
torch.save(optimizer_checkpoint, f"{current_model_directory}/optimizer.pt")
|
| 469 |
+
|
| 470 |
+
training_state_checkpoint = {
|
| 471 |
+
"global_step": global_step,
|
| 472 |
+
"update_step": update_step,
|
| 473 |
+
"tokens_seen": tokens_seen,
|
| 474 |
+
"tokens_seen_before": tokens_seen_before,
|
| 475 |
+
"update_time": update_time,
|
| 476 |
+
}
|
| 477 |
+
with open(f"{current_model_directory}/training_state.json", "w") as f:
|
| 478 |
+
json.dump(training_state_checkpoint, f, indent=4)
|
| 479 |
+
|
| 480 |
+
# save wandb related info
|
| 481 |
+
if not args.unset_wandb:
|
| 482 |
+
wandb_info = {
|
| 483 |
+
"wandb_id": wandb.run.id,
|
| 484 |
+
}
|
| 485 |
+
with open(f"{args.save_dir}/wandb.json", "w") as f:
|
| 486 |
+
json.dump(wandb_info, f, indent=4)
|
| 487 |
+
|
| 488 |
+
# evaluation
|
| 489 |
+
if update_step % args.eval_every == 0:
|
| 490 |
+
logger.info(f"Performing evaluation at step {update_step}")
|
| 491 |
+
total_loss, evaluated_on_tokens, perplexity = evaluate_model(
|
| 492 |
+
model, tokenizer, pad_idx, global_rank, world_size, device, args
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
if global_rank == 0:
|
| 496 |
+
if not args.unset_wandb:
|
| 497 |
+
wandb.log(
|
| 498 |
+
{
|
| 499 |
+
"eval_loss": total_loss,
|
| 500 |
+
"eval_perplexity": perplexity,
|
| 501 |
+
"eval_tokens": evaluated_on_tokens,
|
| 502 |
+
},
|
| 503 |
+
step=update_step,
|
| 504 |
+
)
|
| 505 |
+
logger.info(f"Eval loss at step {update_step}: {total_loss}, Eval perplexity: {perplexity}")
|
| 506 |
+
|
| 507 |
+
if not layer_wise_flag:
|
| 508 |
+
lr = optimizer.param_groups[0]["lr"]
|
| 509 |
+
else:
|
| 510 |
+
lr = list(optimizer_dict.values())[0].param_groups[0]["lr"]
|
| 511 |
+
tokens_in_update = tokens_seen - tokens_seen_before
|
| 512 |
+
tokens_seen_before = tokens_seen
|
| 513 |
+
batches_in_update = args.gradient_accumulation * world_size
|
| 514 |
+
if not layer_wise_flag:
|
| 515 |
+
total_svd_count = getting_svd_cnt(optimizer)
|
| 516 |
+
else:
|
| 517 |
+
total_svd_count = 0
|
| 518 |
+
|
| 519 |
+
if global_rank == 0:
|
| 520 |
+
if not args.unset_wandb:
|
| 521 |
+
wandb.log(
|
| 522 |
+
{
|
| 523 |
+
"loss": avg_loss,
|
| 524 |
+
"lr": lr,
|
| 525 |
+
"update_step": update_step,
|
| 526 |
+
"tokens_seen": tokens_seen,
|
| 527 |
+
"total_svd_count": total_svd_count,
|
| 528 |
+
"throughput_tokens": tokens_in_update / update_time,
|
| 529 |
+
"throughput_examples": args.total_batch_size / update_time,
|
| 530 |
+
"throughput_batches": batches_in_update / update_time,
|
| 531 |
+
},
|
| 532 |
+
step=update_step,
|
| 533 |
+
)
|
| 534 |
+
update_time = time.time()
|
| 535 |
+
|
| 536 |
+
# ##############################
|
| 537 |
+
# END of training loop
|
| 538 |
+
# ##############################
|
| 539 |
+
logger.info("Training finished")
|
| 540 |
+
if global_rank == 0:
|
| 541 |
+
pbar.close()
|
| 542 |
+
|
| 543 |
+
current_model_directory = f"{args.save_dir}/model_{update_step}"
|
| 544 |
+
if global_rank == 0 and not os.path.exists(current_model_directory):
|
| 545 |
+
logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}")
|
| 546 |
+
os.makedirs(args.save_dir, exist_ok=True)
|
| 547 |
+
# Bug-1 fix: unwrap DDP/compiled model for saving; works in both single-GPU and multi-GPU modes
|
| 548 |
+
unwrapped_model = model.module if hasattr(model, 'module') else model
|
| 549 |
+
unwrapped_model.save_pretrained(current_model_directory, max_shard_size="500GB", from_pt=True)
|
| 550 |
+
saving_model_weight(unwrapped_model, f"{current_model_directory}/pytorch_model.bin", args)
|
| 551 |
+
|
| 552 |
+
optimizer_checkpoint = {
|
| 553 |
+
"optimizer": optimizer.state_dict(),
|
| 554 |
+
"scheduler": scheduler.state_dict(),
|
| 555 |
+
"update_step": update_step,
|
| 556 |
+
"global_step": global_step,
|
| 557 |
+
"config": run_config,
|
| 558 |
+
"wandb": wandb.run.dir if not args.unset_wandb else None,
|
| 559 |
+
"dtype": args.dtype,
|
| 560 |
+
}
|
| 561 |
+
torch.save(optimizer_checkpoint, f"{current_model_directory}/optimizer.pt")
|
| 562 |
+
|
| 563 |
+
training_state_checkpoint = {
|
| 564 |
+
"global_step": global_step,
|
| 565 |
+
"update_step": update_step,
|
| 566 |
+
"tokens_seen": tokens_seen,
|
| 567 |
+
"tokens_seen_before": tokens_seen_before,
|
| 568 |
+
"update_time": update_time,
|
| 569 |
+
}
|
| 570 |
+
with open(f"{current_model_directory}/training_state.json", "w") as f:
|
| 571 |
+
json.dump(training_state_checkpoint, f, indent=4)
|
| 572 |
+
|
| 573 |
+
# Final evaluation
|
| 574 |
+
logger.info("Running final evaluation")
|
| 575 |
+
model.eval()
|
| 576 |
+
del loss, optimizer, scheduler
|
| 577 |
+
import gc
|
| 578 |
+
|
| 579 |
+
gc.collect()
|
| 580 |
+
torch.cuda.empty_cache()
|
| 581 |
+
|
| 582 |
+
total_loss, evaluated_on_tokens, perplexity = evaluate_model(model, tokenizer, pad_idx, global_rank, world_size, device, args)
|
| 583 |
+
|
| 584 |
+
if global_rank == 0:
|
| 585 |
+
if not args.unset_wandb:
|
| 586 |
+
wandb.log(
|
| 587 |
+
{
|
| 588 |
+
"final_eval_loss": total_loss,
|
| 589 |
+
"final_eval_perplexity": perplexity,
|
| 590 |
+
"final_eval_tokens": evaluated_on_tokens,
|
| 591 |
+
},
|
| 592 |
+
step=update_step,
|
| 593 |
+
)
|
| 594 |
+
logger.info(f"Final eval loss: {total_loss}, Final eval perplexity: {perplexity}")
|
| 595 |
+
|
| 596 |
+
logger.info("Script finished successfully")
|
| 597 |
+
print(f"Rank {global_rank} finished successfully")
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
if __name__ == "__main__":
|
| 601 |
+
print("Starting script")
|
| 602 |
+
args = parse_args(None)
|
| 603 |
+
main(args)
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/components/__init__.py
ADDED
|
File without changes
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/components/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (197 Bytes). View file
|
|
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/components/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (235 Bytes). View file
|
|
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/components/__pycache__/checkpoint.cpython-310.pyc
ADDED
|
Binary file (1.92 kB). View file
|
|
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/components/__pycache__/checkpoint.cpython-311.pyc
ADDED
|
Binary file (3.7 kB). View file
|
|
|
1b_archs_fwe/transformer_1b_fwe_lion_lr3e_4_b1_0_9_b2_0_99_20260519_061741/exp_data/flame/components/checkpoint.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from datetime import timedelta
|
| 9 |
+
from io import BytesIO
|
| 10 |
+
from typing import Any, Dict, List
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch.distributed.checkpoint.stateful import Stateful
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class TrainState(Stateful):
|
| 18 |
+
step: int = 0
|
| 19 |
+
skipped_step: int = 0
|
| 20 |
+
token: int = 0
|
| 21 |
+
elapsed: timedelta = timedelta(0)
|
| 22 |
+
global_avg_losses: List[float] = field(default_factory=list)
|
| 23 |
+
global_max_losses: List[float] = field(default_factory=list)
|
| 24 |
+
log_steps: List[int] = field(default_factory=list)
|
| 25 |
+
|
| 26 |
+
def state_dict(self) -> Dict[str, Any]:
|
| 27 |
+
# Only checkpoint global_avg_losses and global_max_losses per log frequency
|
| 28 |
+
# to avoid sync overhead in every iteration.
|
| 29 |
+
global_avg_losses_bytes = BytesIO()
|
| 30 |
+
torch.save(self.global_avg_losses, global_avg_losses_bytes)
|
| 31 |
+
global_max_losses_bytes = BytesIO()
|
| 32 |
+
torch.save(self.global_max_losses, global_max_losses_bytes)
|
| 33 |
+
log_steps_bytes = BytesIO()
|
| 34 |
+
torch.save(self.log_steps, log_steps_bytes)
|
| 35 |
+
return {
|
| 36 |
+
"step": torch.tensor(self.step, dtype=torch.int32),
|
| 37 |
+
"skipped_step": torch.tensor(self.skipped_step, dtype=torch.int32),
|
| 38 |
+
"token": torch.tensor(self.token, dtype=torch.int64),
|
| 39 |
+
"elapsed": self.elapsed,
|
| 40 |
+
"global_avg_losses": global_avg_losses_bytes,
|
| 41 |
+
"global_max_losses": global_max_losses_bytes,
|
| 42 |
+
"log_steps": log_steps_bytes,
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
def load_state_dict(self, state_dict) -> None:
|
| 46 |
+
self.step = state_dict["step"].item()
|
| 47 |
+
self.skipped_step = state_dict.get("skipped_step", 0).item()
|
| 48 |
+
self.token = state_dict["token"].item()
|
| 49 |
+
self.elapsed = state_dict["elapsed"]
|
| 50 |
+
state_dict["global_avg_losses"].seek(0)
|
| 51 |
+
self.global_avg_losses = torch.load(
|
| 52 |
+
state_dict["global_avg_losses"], weights_only=False
|
| 53 |
+
)
|
| 54 |
+
state_dict["global_max_losses"].seek(0)
|
| 55 |
+
self.global_max_losses = torch.load(
|
| 56 |
+
state_dict["global_max_losses"], weights_only=False
|
| 57 |
+
)
|
| 58 |
+
state_dict["log_steps"].seek(0)
|
| 59 |
+
self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
|