Add transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +9 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/LICENSE +21 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/README.md +519 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/.metadata +3 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__0_0.distcp +3 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__1_0.distcp +3 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__2_0.distcp +3 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__3_0.distcp +3 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__4_0.distcp +3 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__5_0.distcp +3 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__6_0.distcp +3 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__7_0.distcp +3 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/config.json +34 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/blt_transformer_1000hash.json +98 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/blt_transformer_1_5B.json +99 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/blt_transformer_380M.json +98 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/delta_net_1B.json +29 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/delta_net_340M.json +26 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/gated_deltanet_1B.json +22 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/gated_deltanet_340M.json +22 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/gated_deltanet_h_340M.json +28 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/gla_340M.json +24 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/gla_7B.json +25 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/gsa_340M.json +29 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/mergenet_340M.json +34 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/mergenet_64M.json +34 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/qwen3_next_1B.json +44 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/qwen3_next_350M.json +44 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/transformer_1B.json +22 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/transformer_340M.json +18 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/transformer_7B.json +21 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__init__.py +1 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/__init__.cpython-310.pyc +0 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/__init__.cpython-311.pyc +0 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/__init__.cpython-313.pyc +0 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/config_manager.cpython-310.pyc +0 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/config_manager.cpython-311.pyc +0 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/data.cpython-310.pyc +0 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/data.cpython-311.pyc +0 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/train.cpython-310.pyc +0 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/train.cpython-311.pyc +0 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/train.cpython-313.pyc +0 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/c4_test.py +603 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/components/__init__.py +0 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/components/__pycache__/__init__.cpython-310.pyc +0 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/components/__pycache__/__init__.cpython-311.pyc +0 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/components/__pycache__/checkpoint.cpython-310.pyc +0 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/components/__pycache__/checkpoint.cpython-311.pyc +0 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/components/checkpoint.py +59 -0
- transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/config_manager.py +981 -0
.gitattributes
CHANGED
|
@@ -483,3 +483,12 @@ transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_marsadamw_lr3e_3_b1_0_95_b2_0_99
|
|
| 483 |
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_marsadamw_lr3e_3_b1_0_95_b2_0_99_eps_1e_15_20260511_072915/exp_data/checkpoint/step-30720/__5_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 484 |
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_marsadamw_lr3e_3_b1_0_95_b2_0_99_eps_1e_15_20260511_072915/exp_data/checkpoint/step-30720/__6_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 485 |
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_marsadamw_lr3e_3_b1_0_95_b2_0_99_eps_1e_15_20260511_072915/exp_data/checkpoint/step-30720/__7_0.distcp filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 483 |
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_marsadamw_lr3e_3_b1_0_95_b2_0_99_eps_1e_15_20260511_072915/exp_data/checkpoint/step-30720/__5_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 484 |
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_marsadamw_lr3e_3_b1_0_95_b2_0_99_eps_1e_15_20260511_072915/exp_data/checkpoint/step-30720/__6_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 485 |
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_marsadamw_lr3e_3_b1_0_95_b2_0_99_eps_1e_15_20260511_072915/exp_data/checkpoint/step-30720/__7_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 486 |
+
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/.metadata filter=lfs diff=lfs merge=lfs -text
|
| 487 |
+
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__0_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 488 |
+
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__1_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 489 |
+
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__2_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 490 |
+
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__3_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 491 |
+
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__4_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 492 |
+
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__5_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 493 |
+
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__6_0.distcp filter=lfs diff=lfs merge=lfs -text
|
| 494 |
+
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__7_0.distcp filter=lfs diff=lfs merge=lfs -text
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023-2025 Songlin Yang, Yu Zhang
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/README.md
ADDED
|
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
|
| 3 |
+
# 🔥 Flame: Flash Language Modeling Made Easy
|
| 4 |
+
|
| 5 |
+
[](https://deepwiki.com/fla-org/flame)
|
| 6 |
+
|
| 7 |
+
</div>
|
| 8 |
+
|
| 9 |
+
Welcome to 🔥 `flame`, a minimal and efficient framework built on `torchtitan` for language models with blazing efficiency.
|
| 10 |
+
|
| 11 |
+
**Feature Highlights:**
|
| 12 |
+
|
| 13 |
+
- 🚀 Minimal, easy-to-use, extensible training framework
|
| 14 |
+
- 🤗 Seamless integration with `fla` and `transformers`
|
| 15 |
+
- 🔄 Zero-cost data preprocessing: online tokenization, dataset shuffling, and multiple datasets support
|
| 16 |
+
- 🔮 4D parallelism (coming soon)
|
| 17 |
+
|
| 18 |
+
## Setup
|
| 19 |
+
|
| 20 |
+
To get started, clone the `flame` repository and install the required dependencies:
|
| 21 |
+
|
| 22 |
+
```bash
|
| 23 |
+
git clone https://github.com/fla-org/flame.git
|
| 24 |
+
cd flame
|
| 25 |
+
pip install .
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
Install the latest version of fla
|
| 29 |
+
```
|
| 30 |
+
pip uninstall flash-linear-attention && pip install -U --no-use-pep517 git+https://github.com/fla-org/flash-linear-attention
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
[Important] Install specific version of torchtitan
|
| 34 |
+
```
|
| 35 |
+
pip install git+https://github.com/pytorch/torchtitan.git@0b44d4c
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
## Dataset Preparation
|
| 40 |
+
To download the dataset to your local disk, create a new Python file with the following content and execute it:
|
| 41 |
+
|
| 42 |
+
```py
|
| 43 |
+
from datasets import load_dataset
|
| 44 |
+
|
| 45 |
+
# load fineweb-edu with parallel processing
|
| 46 |
+
dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", num_proc=64, cache_dir="/your/cache/path")
|
| 47 |
+
|
| 48 |
+
# or load a subset with roughly 100B tokens, suitable for small- or medium-sized experiments
|
| 49 |
+
dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_proc=64, cache_dir="/your/cache/path")
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
## Training Recipes
|
| 53 |
+
|
| 54 |
+
Here's an example of training a 340M FLA Transformer model with a LLaMA-like architecture from scratch on a 100BT subset of the Fineweb-edu corpus ~~in streaming mode~~. (Do not use streaming mode if you are concerned about resuming training.)
|
| 55 |
+
|
| 56 |
+
> [!WARNING]
|
| 57 |
+
> If the dataset is not downloaded beforehand, the streaming mode will attempt to fetch it from a remote server and download it on-the-fly, which can be highly unstable during training due to network issues.
|
| 58 |
+
> For stable training, ensure the dataset is downloaded locally (see [**Dataset Preparation**](#dataset-preparation)). Otherwise, we assume you are only testing the new corpus.
|
| 59 |
+
|
| 60 |
+
```sh
|
| 61 |
+
bash train.sh \
|
| 62 |
+
--job.config_file flame/models/fla.toml \
|
| 63 |
+
--job.dump_folder exp/transformer-340M-4K-10B/batch1.seqlen65536.context4096.warmup1024.update1.steps20480.lr1e-3.cosine \
|
| 64 |
+
--model.config configs/transformer_340M.json \
|
| 65 |
+
--model.tokenizer_path fla-hub/transformer-1.3B-100B \
|
| 66 |
+
--optimizer.name AdamW \
|
| 67 |
+
--optimizer.eps 1e-15 \
|
| 68 |
+
--optimizer.lr 1e-3 \
|
| 69 |
+
--lr_scheduler.warmup_steps 1024 \
|
| 70 |
+
--lr_scheduler.lr_min 0.1 \
|
| 71 |
+
--lr_scheduler.decay_type cosine \
|
| 72 |
+
--training.batch_size 1 \
|
| 73 |
+
--training.seq_len 65536 \
|
| 74 |
+
--training.context_len 4096 \
|
| 75 |
+
--training.varlen \
|
| 76 |
+
--training.gradient_accumulation_steps 1 \
|
| 77 |
+
--training.steps 20480 \
|
| 78 |
+
--training.max_norm 1.0 \
|
| 79 |
+
--training.skip_nan_inf \
|
| 80 |
+
--training.dataset HuggingFaceFW/fineweb-edu \
|
| 81 |
+
--training.dataset_name sample-100BT \
|
| 82 |
+
--training.dataset_split train \
|
| 83 |
+
--training.num_workers 32 \
|
| 84 |
+
--training.prefetch_factor 2 \
|
| 85 |
+
--training.seed 42 \
|
| 86 |
+
--training.compile \
|
| 87 |
+
--checkpoint.interval 2048 \
|
| 88 |
+
--checkpoint.load_step -1 \
|
| 89 |
+
--checkpoint.keep_latest_k 2 \
|
| 90 |
+
--metrics.log_freq 1
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
You can specify the number of GPUs by setting the environment variable `NGPU`, which defaults to 8.
|
| 94 |
+
**For single-GPU debugging, set `NGPU=1`.**
|
| 95 |
+
|
| 96 |
+
We provide several [config files](https://github.com/fla-org/flame/tree/main/configs) for different models.
|
| 97 |
+
By default, the learning rate is set to 1e-3 with a cosine scheduler. Other schedulers, such as WSD (wsd), are also supported.
|
| 98 |
+
|
| 99 |
+
**Key parameters:**
|
| 100 |
+
- `--lr_scheduler.decay_ratio`: The proportion of the steps allocated to the decay phase. The learning rate will remain stable after the warmup period and only start decaying during the last `decay_ratio` portion of the total training steps, which is known as the Warmup-Stable-Decay (WSD) schedule.
|
| 101 |
+
- `--lr_scheduler.warmup_steps`: The number of steps for the learning rate warmup phase.
|
| 102 |
+
- `--training.steps`: Total number of training steps.
|
| 103 |
+
- `--training.batch_size`: Batch size per device, must be 1 if `--training.varlen` is set.
|
| 104 |
+
- `--training.seq_len`: The length of each sequence in the batch, which is concatenated from multiple samples.
|
| 105 |
+
- `--training.context_len`: The max allowed length of a sample. For non-varlen mode, this is equivalent to `seq_len`.
|
| 106 |
+
- `--training.varlen`: Whether to conduct variable-length sequence training.
|
| 107 |
+
- `--training.gradient_accumulation_steps`: Number of gradient accumulation steps.
|
| 108 |
+
|
| 109 |
+
> [!WARNING]
|
| 110 |
+
> The total number of tokens processed per batch, referred to as `global_batch_size`, is calculated as batch_size × gradient_accumulation_steps × num_gpus.
|
| 111 |
+
> Each step processes `global_batch_size * seq_len` tokens.
|
| 112 |
+
> Monitor the value of `global_batch_size`, `warmup_steps`, and `steps` carefully when modifying any of the hyperparameters!
|
| 113 |
+
|
| 114 |
+
For a detailed explanation of all parameters, run:
|
| 115 |
+
|
| 116 |
+
```sh
|
| 117 |
+
bash train.sh -h
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
<details>
|
| 121 |
+
<summary>Usage</summary>
|
| 122 |
+
|
| 123 |
+
```py
|
| 124 |
+
options:
|
| 125 |
+
-h, --help show this help message and exit
|
| 126 |
+
--job.config_file JOB.CONFIG_FILE
|
| 127 |
+
Job config file
|
| 128 |
+
--job.dump_folder JOB.DUMP_FOLDER
|
| 129 |
+
Folder to dump job outputs
|
| 130 |
+
--job.description JOB.DESCRIPTION
|
| 131 |
+
Description of the job
|
| 132 |
+
--job.use_for_integration_test
|
| 133 |
+
Add this config to the integration test suite
|
| 134 |
+
--job.print_args Print the args to terminal
|
| 135 |
+
--model.config MODEL.CONFIG
|
| 136 |
+
Path to the model config
|
| 137 |
+
--model.norm_type MODEL.NORM_TYPE
|
| 138 |
+
Type of layer normalization to use [layernorm,
|
| 139 |
+
np_layernorm, rmsnorm, fused_rmsnorm]
|
| 140 |
+
--model.tokenizer_path MODEL.TOKENIZER_PATH
|
| 141 |
+
Tokenizer path
|
| 142 |
+
--profiling.enable_profiling
|
| 143 |
+
Whether to enable pytorch profiler
|
| 144 |
+
--profiling.save_traces_folder PROFILING.SAVE_TRACES_FOLDER
|
| 145 |
+
Trace files location
|
| 146 |
+
--profiling.profile_freq PROFILING.PROFILE_FREQ
|
| 147 |
+
How often to collect profiler traces, in iterations
|
| 148 |
+
--profiling.enable_memory_snapshot
|
| 149 |
+
Whether to dump memory snapshot
|
| 150 |
+
--profiling.save_memory_snapshot_folder PROFILING.SAVE_MEMORY_SNAPSHOT_FOLDER
|
| 151 |
+
Memeory snapshot files location
|
| 152 |
+
--optimizer.name OPTIMIZER.NAME
|
| 153 |
+
Optimizer to use
|
| 154 |
+
--optimizer.eps OPTIMIZER.EPS
|
| 155 |
+
Epsilon value for the optimizer.
|
| 156 |
+
--optimizer.fused Whether the fused implementation(CUDA only) is used.
|
| 157 |
+
--optimizer.scheduler {wsd,cosine,linear}
|
| 158 |
+
Scheduler to use. Currently supported: wsd, cosine,
|
| 159 |
+
and linear.
|
| 160 |
+
--optimizer.lr OPTIMIZER.LR
|
| 161 |
+
Learning rate to use
|
| 162 |
+
--optimizer.min_lr_ratio OPTIMIZER.MIN_LR_RATIO
|
| 163 |
+
Min lr ratio for lr scheduler
|
| 164 |
+
--optimizer.early_step_in_backward
|
| 165 |
+
Whether to apply optimizer in the backward. Caution,
|
| 166 |
+
optimizer_in_backward is not compatible with gradients
|
| 167 |
+
clipping, users should not call
|
| 168 |
+
register_post_accumulate_grad_hook after the optimizer
|
| 169 |
+
is built.
|
| 170 |
+
--training.batch_size TRAINING.BATCH_SIZE
|
| 171 |
+
Batch size
|
| 172 |
+
--training.seq_len TRAINING.SEQ_LEN
|
| 173 |
+
Sequence length
|
| 174 |
+
--training.context_len TRAINING.CONTEXT_LEN
|
| 175 |
+
Max length allowed for each sequence
|
| 176 |
+
--training.varlen Whether to take sequences of variable length as input
|
| 177 |
+
--training.warmup_steps TRAINING.WARMUP_STEPS
|
| 178 |
+
Steps for lr scheduler warmup, normally 1/5 of
|
| 179 |
+
--training.steps
|
| 180 |
+
--training.gradient_accumulation_steps TRAINING.GRADIENT_ACCUMULATION_STEPS
|
| 181 |
+
Number of steps to accumulate gradients before
|
| 182 |
+
updating parameters
|
| 183 |
+
--training.steps TRAINING.STEPS
|
| 184 |
+
How many train steps to run
|
| 185 |
+
--training.max_norm TRAINING.MAX_NORM
|
| 186 |
+
Max norm for gradient clipping
|
| 187 |
+
--training.skip_nan_inf
|
| 188 |
+
Skip batch updates when NaN or INF gradients are
|
| 189 |
+
encountered during training
|
| 190 |
+
--training.dataset TRAINING.DATASET
|
| 191 |
+
Dataset to use, with comma separated values
|
| 192 |
+
--training.dataset_name TRAINING.DATASET_NAME
|
| 193 |
+
The name of the dataset config, with comma separated
|
| 194 |
+
values if provided
|
| 195 |
+
--training.dataset_split TRAINING.DATASET_SPLIT
|
| 196 |
+
Dataset split to use, with comma separated values if
|
| 197 |
+
provided
|
| 198 |
+
--training.data_dir TRAINING.DATA_DIR
|
| 199 |
+
Data dirs to use, with comma separated values if
|
| 200 |
+
provided
|
| 201 |
+
--training.data_files TRAINING.DATA_FILES
|
| 202 |
+
Data files to use, with comma separated values if
|
| 203 |
+
provided
|
| 204 |
+
--training.data_probs TRAINING.DATA_PROBS
|
| 205 |
+
Data sampling probabilities, with comma separated
|
| 206 |
+
values if provided
|
| 207 |
+
--training.streaming Whether to load dataset in streaming mode, used for
|
| 208 |
+
huge dataset
|
| 209 |
+
--training.num_workers TRAINING.NUM_WORKERS
|
| 210 |
+
Number of subprocesses to use for data loading. 0
|
| 211 |
+
means that the data will be loaded in the main
|
| 212 |
+
process.
|
| 213 |
+
--training.prefetch_factor TRAINING.PREFETCH_FACTOR
|
| 214 |
+
Number of batches loaded in advance by each worker.2
|
| 215 |
+
means there will be a total of 2 * num_workers batches
|
| 216 |
+
prefetched across all workers.
|
| 217 |
+
--training.data_parallel_replicate_degree TRAINING.DATA_PARALLEL_REPLICATE_DEGREE
|
| 218 |
+
The `data_parallel_replicate_degree` argument
|
| 219 |
+
specifies the degree of data parallelism for weight
|
| 220 |
+
replication. When this value is greater than 1,
|
| 221 |
+
weights will be replicated across
|
| 222 |
+
`data_parallel_replicate_degree` ranks. If
|
| 223 |
+
`data_parallel_shard_degree` is also greater than 1,
|
| 224 |
+
the parallelism method used is HSDP (Hybrid Sharded
|
| 225 |
+
Data Parallelism). Otherwise, the parallelism method
|
| 226 |
+
used is DDP (Distributed Data Parallelism). 1 means
|
| 227 |
+
disabled.
|
| 228 |
+
--training.data_parallel_shard_degree TRAINING.DATA_PARALLEL_SHARD_DEGREE
|
| 229 |
+
The `data_parallel_shard_degree` argument specifies
|
| 230 |
+
the degree of data parallelism for weight sharding.
|
| 231 |
+
When this value is greater than 1, weights will be
|
| 232 |
+
sharded across `data_parallel_shard_degree` ranks. If
|
| 233 |
+
`data_parallel_replicate_degree` is also greater than
|
| 234 |
+
1, the parallelism method used is HSDP (Hybrid Sharded
|
| 235 |
+
Data Parallelism). Otherwise, the parallelism method
|
| 236 |
+
used is FSDP (Fully Sharded Data Parallelism). -1
|
| 237 |
+
means leftover ranks will be used (After
|
| 238 |
+
DP_REPLICATE/SP/PP). Note that only
|
| 239 |
+
`data_parallel_shard_degree` can be negative. 1 means
|
| 240 |
+
disabled.
|
| 241 |
+
--training.enable_cpu_offload
|
| 242 |
+
Whether to apply CPU offloading of parameters,
|
| 243 |
+
gradients, and optimizer states in FSDP
|
| 244 |
+
--training.tensor_parallel_degree TRAINING.TENSOR_PARALLEL_DEGREE
|
| 245 |
+
Tensor Parallelism degree. 1 means disabled.
|
| 246 |
+
--training.disable_loss_parallel
|
| 247 |
+
Whether to apply loss parallel when sequence parallel
|
| 248 |
+
is enabled
|
| 249 |
+
--training.mixed_precision_param {bfloat16,float32}
|
| 250 |
+
torch dtype to use for parameters when applying mixed
|
| 251 |
+
precision via FSDP. This feature only takes effect
|
| 252 |
+
when data_parallel_shard_degree > 1
|
| 253 |
+
--training.mixed_precision_reduce {float32}
|
| 254 |
+
torch dtype to use for reductions when applying mixed
|
| 255 |
+
precision via FSDP. This feature only takes effect
|
| 256 |
+
when data_parallel_shard_degree > 1
|
| 257 |
+
--training.compile Whether to compile the model
|
| 258 |
+
--training.gc_freq TRAINING.GC_FREQ
|
| 259 |
+
Python garbage control scheduling interval, in steps
|
| 260 |
+
--training.seed TRAINING.SEED
|
| 261 |
+
Choose the base RNG seed used for training
|
| 262 |
+
--training.deterministic
|
| 263 |
+
Use deterministic algorithms wherever possible, may be
|
| 264 |
+
slower
|
| 265 |
+
--metrics.log_freq METRICS.LOG_FREQ
|
| 266 |
+
How often to log metrics to TensorBoard, in iterations
|
| 267 |
+
--metrics.enable_tensorboard
|
| 268 |
+
Whether to log metrics to TensorBoard
|
| 269 |
+
--metrics.disable_color_printing
|
| 270 |
+
Whether to disable color printing in logs
|
| 271 |
+
--metrics.save_tb_folder METRICS.SAVE_TB_FOLDER
|
| 272 |
+
Folder to dump TensorBoard states
|
| 273 |
+
--metrics.rank_0_only
|
| 274 |
+
Whether to save TensorBoard metrics only for rank 0 or
|
| 275 |
+
for all ranks. When pipeline_parallel_degree is > 1,
|
| 276 |
+
this option uses the 0th rank of the last stage
|
| 277 |
+
pipeline group, which is the only stage that computes
|
| 278 |
+
loss metrics.
|
| 279 |
+
--metrics.enable_wandb
|
| 280 |
+
Whether to log metrics to Weights & Biases
|
| 281 |
+
--experimental.enable_async_tensor_parallel
|
| 282 |
+
Whether to apply async tensor parallel (currently only
|
| 283 |
+
effective when compile is enabled)
|
| 284 |
+
--experimental.pipeline_parallel_degree EXPERIMENTAL.PIPELINE_PARALLEL_DEGREE
|
| 285 |
+
Pipeline Parallelism degree, or number of ranks. 1
|
| 286 |
+
means disabled. If using looped schedules, this still
|
| 287 |
+
specifies the number of physical ranks, not the number
|
| 288 |
+
of stages. Stages per rank are inferred from split
|
| 289 |
+
points degree, and schedule.
|
| 290 |
+
--experimental.pipeline_parallel_split_points EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS [EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS ...]
|
| 291 |
+
Specify comma-separated names of modules to use as the
|
| 292 |
+
beginning of a split point. e.g. "layers.0,layers.2"
|
| 293 |
+
will cause the model to be split into 3 stages, the
|
| 294 |
+
first containing all the layers up to layers.0, the
|
| 295 |
+
second containing layers.0 and up to layers.2, the
|
| 296 |
+
third containing layers.2 and all the remaining
|
| 297 |
+
layers. Note: fully-automated splitting may be enabled
|
| 298 |
+
in the future, but currently the split points must be
|
| 299 |
+
specified manually.
|
| 300 |
+
--experimental.pipeline_parallel_schedule EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE
|
| 301 |
+
Specify the Pipeline Parallel schedule to use. The
|
| 302 |
+
supported schedules are: https://github.com/pytorch/py
|
| 303 |
+
torch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/to
|
| 304 |
+
rch/distributed/pipelining/schedules.py#L2161. The
|
| 305 |
+
schedule must be compatible with the split points and
|
| 306 |
+
stages_per_rank. Looped schedules (e.g.
|
| 307 |
+
Interleaved1F1B) require specifying
|
| 308 |
+
pipeline_parallel_degree = number of ranks, and
|
| 309 |
+
split_points = number of stages - 1
|
| 310 |
+
--experimental.pipeline_parallel_schedule_csv EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE_CSV
|
| 311 |
+
Specify the path to the pipeline parallel schedule csv
|
| 312 |
+
file to use. The pipeline_parallel_schedule argument
|
| 313 |
+
must be either PipelineScheduleSingle,
|
| 314 |
+
PipelineScheduleMulti, or _PipelineScheduleRuntime.
|
| 315 |
+
--experimental.pipeline_parallel_microbatches EXPERIMENTAL.PIPELINE_PARALLEL_MICROBATCHES
|
| 316 |
+
How many microbatches to split the global training
|
| 317 |
+
batch into when using pipeline parallelism. The global
|
| 318 |
+
training batch size must be evenly divisible by the
|
| 319 |
+
number of microbatches. The default value will be the
|
| 320 |
+
number of pipeline stages, if unspecified.
|
| 321 |
+
--experimental.enable_compiled_autograd
|
| 322 |
+
Enable CompiledAutograd to compile the backward.
|
| 323 |
+
--experimental.context_parallel_degree EXPERIMENTAL.CONTEXT_PARALLEL_DEGREE
|
| 324 |
+
Context parallelism degree. 1 means disabled.
|
| 325 |
+
--experimental.context_parallel_rotate_method EXPERIMENTAL.CONTEXT_PARALLEL_ROTATE_METHOD
|
| 326 |
+
The collective to use in context parallel SDPA for kv
|
| 327 |
+
shards exchange. 'allgather' means to all-gather all
|
| 328 |
+
kv shards on ranks after the first sub-SDPA
|
| 329 |
+
computation, 'alltoall' means to all-to-all shuffle
|
| 330 |
+
the kv shards. The default value is 'allgather'.
|
| 331 |
+
--checkpoint.enable_checkpoint
|
| 332 |
+
Whether to enable checkpoint
|
| 333 |
+
--checkpoint.folder CHECKPOINT.FOLDER
|
| 334 |
+
The folder to store the checkpoints. When
|
| 335 |
+
enable_checkpoint is set to true, checkpoints will be
|
| 336 |
+
in {--job.dump_folder}/{--checkpoint.folder}.
|
| 337 |
+
--checkpoint.interval_type CHECKPOINT.INTERVAL_TYPE
|
| 338 |
+
Checkpointing interval unit of measurement ['step',
|
| 339 |
+
'seconds']
|
| 340 |
+
--checkpoint.interval CHECKPOINT.INTERVAL
|
| 341 |
+
Checkpointing interval, in steps or seconds depending
|
| 342 |
+
on --checkpoint.interval_type
|
| 343 |
+
--checkpoint.model_weights_only
|
| 344 |
+
When model_weights_only=True, only model weights will
|
| 345 |
+
be saved at the end of training. With this,
|
| 346 |
+
checkpoints can be loaded using `torch.load(...,
|
| 347 |
+
weights_only=True)` after conversion. When
|
| 348 |
+
model_weights_only=False, the full checkpoint will be
|
| 349 |
+
saved. A full checkpoint includes model, optimizer and
|
| 350 |
+
train_state, which can be used to resume training. The
|
| 351 |
+
default value is false.
|
| 352 |
+
--checkpoint.export_dtype {float16,bfloat16,float32}
|
| 353 |
+
Converts to the specified precision when training
|
| 354 |
+
completes and model_weights_only=true. Currently
|
| 355 |
+
supports float32, float16, and bfloat16. The default
|
| 356 |
+
value is float32.
|
| 357 |
+
--checkpoint.create_seed_checkpoint
|
| 358 |
+
Initializes the full model without applying
|
| 359 |
+
parallelisms, and then saves it as a seed checkpoint.
|
| 360 |
+
Note: requires user to call train.py without
|
| 361 |
+
specifying any parallelisms, e.g. NGPU=1. Could be
|
| 362 |
+
implemented as a separate script, but this way shares
|
| 363 |
+
more code.
|
| 364 |
+
--checkpoint.async_mode CHECKPOINT.ASYNC_MODE
|
| 365 |
+
Which async checkpoint mode to use. Currently there
|
| 366 |
+
are 3 different modes. 1. "disabled": synchronized
|
| 367 |
+
checkpointing will be used. 2. "async":
|
| 368 |
+
torch.distributed.checkpoint.async_save will be used.
|
| 369 |
+
1. "async_with_pinned_mem": this option utilizes a
|
| 370 |
+
dedicated pinned memory space and creates a separate
|
| 371 |
+
process for faster GPU->CPU transfer performance and
|
| 372 |
+
eliminating GIL contention. The cost is increased CPU
|
| 373 |
+
memory usage. If insufficient CPU memory is available,
|
| 374 |
+
performance may degrade due to memory paging. For most
|
| 375 |
+
users, "async" should suffice as the performance
|
| 376 |
+
overhead is typically small (on the order of tens of
|
| 377 |
+
seconds) compared to checkpointing frequency. This
|
| 378 |
+
mode can be employed to pursue near-zero checkpointing
|
| 379 |
+
times (e.g., < 1 second) given appropriate hardware
|
| 380 |
+
support such as ample CPU memory and fast PCIe.
|
| 381 |
+
"disabled" is the default mode.
|
| 382 |
+
--checkpoint.keep_latest_k CHECKPOINT.KEEP_LATEST_K
|
| 383 |
+
Keeps only the latest k checkpoints, and purging older
|
| 384 |
+
ones. If 0, keep all checkpoints. 0 is the default
|
| 385 |
+
value.
|
| 386 |
+
--checkpoint.load_step CHECKPOINT.LOAD_STEP
|
| 387 |
+
Load the checkpoint at the specified step. If -1, load
|
| 388 |
+
the latest checkpoint.
|
| 389 |
+
--float8.enable_float8_linear
|
| 390 |
+
If true, swaps `torch.nn.Linear` with `Float8Linear`.
|
| 391 |
+
This feature requires you to install 'torchao' which
|
| 392 |
+
can be found here: https://github.com/pytorch/ao
|
| 393 |
+
--float8.enable_fsdp_float8_all_gather
|
| 394 |
+
Whether enable float8 all-gather in FSDP
|
| 395 |
+
--float8.precompute_float8_dynamic_scale_for_fsdp
|
| 396 |
+
Whether precompute float8 scales dynamically for FSDP
|
| 397 |
+
--float8.scaling_type_input {dynamic,delayed}
|
| 398 |
+
float8 scaling for input, dynamic (default) or delayed
|
| 399 |
+
--float8.scaling_type_weight FLOAT8.SCALING_TYPE_WEIGHT
|
| 400 |
+
float8 scaling for input, dynamic (default) or delayed
|
| 401 |
+
--float8.scaling_type_grad_output FLOAT8.SCALING_TYPE_GRAD_OUTPUT
|
| 402 |
+
float8 scaling for input, dynamic (default) or delayed
|
| 403 |
+
--comm.init_timeout_seconds COMM.INIT_TIMEOUT_SECONDS
|
| 404 |
+
Timeout for communication operations, during
|
| 405 |
+
initialization and first train step.
|
| 406 |
+
--comm.train_timeout_seconds COMM.TRAIN_TIMEOUT_SECONDS
|
| 407 |
+
Timeout for communication operations after the first
|
| 408 |
+
train step -- usually a tighter bound than during
|
| 409 |
+
initialization.
|
| 410 |
+
--comm.trace_buf_size COMM.TRACE_BUF_SIZE
|
| 411 |
+
Flight recorder ring buffer size, >0 means recording
|
| 412 |
+
by default, 0 means disabled
|
| 413 |
+
--memory_estimation.enabled
|
| 414 |
+
Whether to estimate memory usage for FSDP
|
| 415 |
+
--memory_estimation.disable_fake_mode
|
| 416 |
+
Whether to estimate memory under FakeTensorMode
|
| 417 |
+
```
|
| 418 |
+
</details>
|
| 419 |
+
|
| 420 |
+
### Training with variable-length inputs
|
| 421 |
+
When you set the `--training.varlen` flag, you're enabling a more efficient training method that packs multiple documents together into a single long sequence, eliminating the need for padding.
|
| 422 |
+
This is particularly useful when your dataset contains documents of varying lengths.
|
| 423 |
+
Let's break down how `--training.seq_len` and `--training.context_len` work in this mode.
|
| 424 |
+
|
| 425 |
+
* `--training.seq_len` (Packed Sequence Length): This is the total length of the final sequence fed to the model on one device. Instead of processing one document at a time, the dataloader takes multiple documents (each split to sequences no longer than `context_len`), concatenates them end-to-end, and creates a single long sequence of length `seq_len`.
|
| 426 |
+
* `--training.context_len` (Sample Length): This parameter defines the maximum number of tokens for a single document or sample. If a document from the dataset is longer than `context_len`, it will be truncated. For example, if `--training.context_len` is set to 4,096, a document with 5,000 tokens will be cut down to its first 4,096 tokens, leaving the left tokens as another independent sequence, while a document with 3000 tokens remains unchanged.
|
| 427 |
+
|
| 428 |
+
### Training with `torch.compile`
|
| 429 |
+
|
| 430 |
+
Starting from `torch 2.0`, `torch.compile` has been introduced as a new feature to seamlessly accelerate training processes.
|
| 431 |
+
In `flame`, one can simply enable `torch.compile` by adding `--training.compile` flag to your training script.
|
| 432 |
+
|
| 433 |
+
However, `fla` has integrated numerous fused kernels for acceleration, which may potentially conflict with `torch.compile`.
|
| 434 |
+
We are actively working on resolving these issues to make compilation transparent to users.
|
| 435 |
+
In the meantime, please ensure you are using the latest dependencies.
|
| 436 |
+
|
| 437 |
+
Specifically, **we recommend using `torch>=2.6` and `triton>=3.0`**.
|
| 438 |
+
|
| 439 |
+
### Training with multiple datasets
|
| 440 |
+
|
| 441 |
+
If you wish to train a model with all-round capabilities (e.g., code, math, and multilingual ability), it's necessary to train on multiple datasets.
|
| 442 |
+
`flame` allows training with multiple datasets easily.
|
| 443 |
+
For example, you can specify the following arguments to train on 6 datasets with different proportions:
|
| 444 |
+
|
| 445 |
+
```sh
|
| 446 |
+
--training.dataset HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus \
|
| 447 |
+
--training.data_probs 0.6,0.15,0.15,0.014,0.058,0.028 \
|
| 448 |
+
```
|
| 449 |
+
|
| 450 |
+
### ~Finalizing training~
|
| 451 |
+
|
| 452 |
+
> [!NOTE]
|
| 453 |
+
> We have done this conversion automatically in the training script since our latest updates.
|
| 454 |
+
|
| 455 |
+
Once training is complete, you may want to convert the distributed checkpoints (DCPs) into the 🤗 format for broader use.
|
| 456 |
+
To facilitate this, we provide a straightforward conversion script:
|
| 457 |
+
|
| 458 |
+
```sh
|
| 459 |
+
python -m flame.utils.convert_dcp_to_hf --path <path_to_model> --step <step> --config <path_to_config> --tokenizer <path_to_tokenizer>
|
| 460 |
+
```
|
| 461 |
+
After this, your model will be in the 🤗 format, ready to be shared or deployed.
|
| 462 |
+
You can then easily publish your model using the `huggingface_hub` for wider accessibility.
|
| 463 |
+
|
| 464 |
+
### Continual training
|
| 465 |
+
|
| 466 |
+
If you wish to build upon a strong pre-trained model (in 🤗 format) and continue training, we also offer a script to convert the 🤗 format model back into DCP format.
|
| 467 |
+
This allows you to seamlessly resume training with `flame`.
|
| 468 |
+
```sh
|
| 469 |
+
python -m flame.utils.convert_hf_to_dcp --model <path_to_hf> --checkpoint <path_to_dcp/checkpoint/step-0>
|
| 470 |
+
```
|
| 471 |
+
Here, `<path_to_dcp>` is the directory where your distributed checkpoints will be stored.
|
| 472 |
+
The checkpoint is intentionally saved at `<step-0>` within the checkpoint folder to ensure it is loadable by `flame` during the initial training step, similar to how a seed checkpoint is handled.
|
| 473 |
+
|
| 474 |
+
Once the conversion is complete, you can proceed with training using `flame` as usual, continuing from where the pretrained model left off.
|
| 475 |
+
|
| 476 |
+
## Multi-node training
|
| 477 |
+
|
| 478 |
+
If you have access to multi-node GPUs, consider leveraging them for optimal performance.
|
| 479 |
+
This process is straightforward and well-documented in the PyTorch [docs](https://pytorch.org/docs/stable/elastic/run.html).
|
| 480 |
+
|
| 481 |
+
To set up multi-node training:
|
| 482 |
+
* Set the environment variables `MASTER_ADDR=<ip>` and `MASTER_PORT=<port>` before running the training script across all nodes.
|
| 483 |
+
* If you're using a job scheduler like Slurm, it will handle these variables for you.
|
| 484 |
+
|
| 485 |
+
`torchtitan` provides a [Slurm script](https://github.com/pytorch/torchtitan/blob/main/multinode_trainer.slurm) for multi-node training, which you can use as a reference or starting point.
|
| 486 |
+
|
| 487 |
+
## Custom models
|
| 488 |
+
|
| 489 |
+
`flame` supports custom model architectures through seamless integration with the Hugging Face `transformers` library. To add your own model:
|
| 490 |
+
|
| 491 |
+
1. Create a new model directory under `custom_models/` (see `custom_models/sba` for a complete example)
|
| 492 |
+
2. Implement your model classes and configuration:
|
| 493 |
+
- Define a config class inheriting from `PretrainedConfig` (see `custom_models/sba/config_sba.py` for an example)
|
| 494 |
+
- Create model classes inheriting from `PreTrainedModel` (see `custom_models/sba/modeling_sba.py` for an example)
|
| 495 |
+
3. Register your models in `__init__.py`:
|
| 496 |
+
- Import your model classes and config classes
|
| 497 |
+
- Register your models with the `AutoModelForCausalLM`, `AutoModel` and `AutoConfig` classes (see `custom_models/sba/__init__.py` for an example)
|
| 498 |
+
4. Create a config file for your custom model, just need to specify the `model_type` to the one you just named for your custom model (example: `configs/sba_340m.json`).
|
| 499 |
+
5. Training is extremely simple, you can just use the `flame.train.py` script to train your custom model.
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
## Citation
|
| 508 |
+
|
| 509 |
+
If you find `flame` helpful for your work, please consider citing it.
|
| 510 |
+
|
| 511 |
+
```bib
|
| 512 |
+
@software{yang2025flame,
|
| 513 |
+
title = {Flame: Flash Language Modeling Made Easy},
|
| 514 |
+
author = {Zhang, Yu and Yang, Songlin},
|
| 515 |
+
url = {https://github.com/fla-org/flame},
|
| 516 |
+
month = jan,
|
| 517 |
+
year = {2025}
|
| 518 |
+
}
|
| 519 |
+
```
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/.metadata
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c8eeb4fb92a760e8ebd173686dafcb1195996857abe30e2234f4d72008e5d7a2
|
| 3 |
+
size 1411174
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__0_0.distcp
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:10974149c457b268f3ad996002bd1b8d2a8baec49773c19fbf6e81f464ea1746
|
| 3 |
+
size 2001640303
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__1_0.distcp
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b6ddbd0ab3bd1163f7e7822c38c3a15a941cc9ad40d47e205cddae4c4a5f34e6
|
| 3 |
+
size 1995395505
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__2_0.distcp
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:23b6ad567d472445d5012f0a9e364b1fd43f8c0ebca15236912a0f7355ee38bb
|
| 3 |
+
size 2001569722
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__3_0.distcp
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3d3a1a15c84b55588b149230697334d3a99ab4a0d7ad238791798bbdbe14f80a
|
| 3 |
+
size 1992747645
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__4_0.distcp
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:488ecb9a14890d61c18143655835ce910b20c188d7c3fd670f911d9c545927ac
|
| 3 |
+
size 2217877888
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__5_0.distcp
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:998721db82dda199a9f0eb5fccd0984eed4f1856487fc30e607dde94f035c5c4
|
| 3 |
+
size 2217843999
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__6_0.distcp
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:864911064ddd46130fe60104a1ab977dc785989309cdc75b787c1351d5a9d725
|
| 3 |
+
size 1991234513
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/checkpoint/step-30720/__7_0.distcp
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7cb288436d4b820c78953e25bf2e8927315e16a284b82cfaccefa5b9b711ece9
|
| 3 |
+
size 1996738056
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/config.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"TransformerForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"bos_token_id": 1,
|
| 6 |
+
"dtype": "float32",
|
| 7 |
+
"elementwise_affine": true,
|
| 8 |
+
"eos_token_id": 2,
|
| 9 |
+
"fuse_cross_entropy": true,
|
| 10 |
+
"fuse_linear_cross_entropy": false,
|
| 11 |
+
"fuse_norm": true,
|
| 12 |
+
"fuse_swiglu": true,
|
| 13 |
+
"hidden_act": "swish",
|
| 14 |
+
"hidden_ratio": 4,
|
| 15 |
+
"hidden_size": 2048,
|
| 16 |
+
"initializer_range": 0.02,
|
| 17 |
+
"intermediate_size": null,
|
| 18 |
+
"max_position_embeddings": 8192,
|
| 19 |
+
"model_type": "transformer",
|
| 20 |
+
"norm_eps": 1e-06,
|
| 21 |
+
"num_heads": 32,
|
| 22 |
+
"num_hidden_layers": 24,
|
| 23 |
+
"num_kv_heads": null,
|
| 24 |
+
"pad_token_id": 2,
|
| 25 |
+
"qk_norm": false,
|
| 26 |
+
"qkv_bias": false,
|
| 27 |
+
"rope_theta": 10000.0,
|
| 28 |
+
"tie_word_embeddings": false,
|
| 29 |
+
"transformers_version": "4.57.6",
|
| 30 |
+
"use_cache": true,
|
| 31 |
+
"use_l2warp": false,
|
| 32 |
+
"vocab_size": 32000,
|
| 33 |
+
"window_size": null
|
| 34 |
+
}
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/blt_transformer_1000hash.json
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "blt",
|
| 3 |
+
"vocab_size": 260,
|
| 4 |
+
"max_position_embeddings": 4096,
|
| 5 |
+
"initializer_range": 0.02,
|
| 6 |
+
"tie_word_embeddings": false,
|
| 7 |
+
"patch_in_forward": true,
|
| 8 |
+
"patch_size": 4,
|
| 9 |
+
"patching_mode": "entropy",
|
| 10 |
+
"patching_threshold": 1.335442066192627,
|
| 11 |
+
"patching_batch_size": 1,
|
| 12 |
+
"max_patch_length": null,
|
| 13 |
+
"patching_device": "cuda",
|
| 14 |
+
"realtime_patching": true,
|
| 15 |
+
"patching_threshold_add": null,
|
| 16 |
+
"monotonicity": false,
|
| 17 |
+
"cross_attn_k": 2,
|
| 18 |
+
"encoder_hash_byte_group_size": [3, 4, 5, 6, 7, 8],
|
| 19 |
+
"encoder_hash_byte_group_vocab": 1000,
|
| 20 |
+
"encoder_hash_byte_group_nb_functions": 1,
|
| 21 |
+
"patcher_config": {
|
| 22 |
+
"model_type": "blt_patcher",
|
| 23 |
+
"vocab_size": 260,
|
| 24 |
+
"hidden_size": 512,
|
| 25 |
+
"num_hidden_layers": 7,
|
| 26 |
+
"num_attention_heads": 8,
|
| 27 |
+
"num_key_value_heads": 8,
|
| 28 |
+
"max_position_embeddings": 8192,
|
| 29 |
+
"rms_norm_eps": 1e-5,
|
| 30 |
+
"dropout": 0.0,
|
| 31 |
+
"intermediate_size": 1365,
|
| 32 |
+
"hidden_act": "silu",
|
| 33 |
+
"initializer_range": 0.02,
|
| 34 |
+
"rope_parameters": {"rope_type": "default",
|
| 35 |
+
"rope_theta": 500000
|
| 36 |
+
}
|
| 37 |
+
},
|
| 38 |
+
"encoder_config": {
|
| 39 |
+
"model_type": "blt_local_encoder",
|
| 40 |
+
"vocab_size": 260,
|
| 41 |
+
"hidden_size": 512,
|
| 42 |
+
"hidden_size_global": 1024,
|
| 43 |
+
"num_hidden_layers": 1,
|
| 44 |
+
"num_attention_heads": 8,
|
| 45 |
+
"num_key_value_heads": 8,
|
| 46 |
+
"head_dim": 64,
|
| 47 |
+
"intermediate_size": 1365,
|
| 48 |
+
"rms_norm_eps": 1e-5,
|
| 49 |
+
"dropout": 0.0,
|
| 50 |
+
"max_position_embeddings": 24576,
|
| 51 |
+
"cross_attn_all_layers": false,
|
| 52 |
+
"cross_attn_k": 2,
|
| 53 |
+
"hidden_act": "silu",
|
| 54 |
+
"initializer_range": 0.02,
|
| 55 |
+
"rope_parameters": {"rope_type": "default",
|
| 56 |
+
"rope_theta": 500000
|
| 57 |
+
}
|
| 58 |
+
},
|
| 59 |
+
"decoder_config": {
|
| 60 |
+
"model_type": "blt_local_decoder",
|
| 61 |
+
"vocab_size": 260,
|
| 62 |
+
"hidden_size": 512,
|
| 63 |
+
"hidden_size_global": 1024,
|
| 64 |
+
"num_hidden_layers": 9,
|
| 65 |
+
"num_attention_heads": 8,
|
| 66 |
+
"num_key_value_heads": 8,
|
| 67 |
+
"head_dim": 64,
|
| 68 |
+
"intermediate_size": 1365,
|
| 69 |
+
"rms_norm_eps": 1e-5,
|
| 70 |
+
"dropout": 0.0,
|
| 71 |
+
"max_position_embeddings": 24576,
|
| 72 |
+
"cross_attn_all_layers": true,
|
| 73 |
+
"cross_attn_k": 2,
|
| 74 |
+
"hidden_act": "silu",
|
| 75 |
+
"initializer_range": 0.02,
|
| 76 |
+
"rope_parameters": {"rope_type": "default",
|
| 77 |
+
"rope_theta": 500000
|
| 78 |
+
}
|
| 79 |
+
},
|
| 80 |
+
"global_config": {
|
| 81 |
+
"model_type": "blt_global_transformer",
|
| 82 |
+
"hidden_size": 1024,
|
| 83 |
+
"num_hidden_layers": 25,
|
| 84 |
+
"num_attention_heads": 8,
|
| 85 |
+
"num_key_value_heads": 8,
|
| 86 |
+
"head_dim": 128,
|
| 87 |
+
"intermediate_size": 2731,
|
| 88 |
+
"rms_norm_eps": 1e-5,
|
| 89 |
+
"dropout": 0.0,
|
| 90 |
+
"max_position_embeddings": 4096,
|
| 91 |
+
"hidden_act": "silu",
|
| 92 |
+
"initializer_range": 0.02,
|
| 93 |
+
"rope_parameters": {"rope_type": "default",
|
| 94 |
+
"rope_theta": 500000
|
| 95 |
+
},
|
| 96 |
+
"encoder_cross_output_size": null
|
| 97 |
+
}
|
| 98 |
+
}
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/blt_transformer_1_5B.json
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "blt",
|
| 3 |
+
"vocab_size": 260,
|
| 4 |
+
"max_position_embeddings": 4096,
|
| 5 |
+
"initializer_range": 0.02,
|
| 6 |
+
"tie_word_embeddings": false,
|
| 7 |
+
"patch_in_forward": true,
|
| 8 |
+
"patch_size": 4,
|
| 9 |
+
"patching_mode": "entropy",
|
| 10 |
+
"patching_threshold": 1.335442066192627,
|
| 11 |
+
"patching_batch_size": 1,
|
| 12 |
+
"max_patch_length": null,
|
| 13 |
+
"patching_device": "cuda",
|
| 14 |
+
"realtime_patching": true,
|
| 15 |
+
"patching_threshold_add": null,
|
| 16 |
+
"monotonicity": false,
|
| 17 |
+
"cross_attn_k": 2,
|
| 18 |
+
"encoder_hash_byte_group_size": [3, 4, 5, 6, 7, 8],
|
| 19 |
+
"encoder_hash_byte_group_vocab": 500,
|
| 20 |
+
"encoder_hash_byte_group_nb_functions": 1,
|
| 21 |
+
"patcher_config": {
|
| 22 |
+
"model_type": "blt_patcher",
|
| 23 |
+
"vocab_size": 260,
|
| 24 |
+
"hidden_size": 768,
|
| 25 |
+
"num_hidden_layers": 7,
|
| 26 |
+
"num_attention_heads": 12,
|
| 27 |
+
"num_key_value_heads": 12,
|
| 28 |
+
"max_position_embeddings": 8192,
|
| 29 |
+
"rms_norm_eps": 1e-5,
|
| 30 |
+
"dropout": 0.0,
|
| 31 |
+
"intermediate_size": 2048,
|
| 32 |
+
"hidden_act": "silu",
|
| 33 |
+
"initializer_range": 0.02,
|
| 34 |
+
"rope_parameters": {"rope_type": "default",
|
| 35 |
+
"rope_theta": 500000
|
| 36 |
+
}
|
| 37 |
+
},
|
| 38 |
+
"encoder_config": {
|
| 39 |
+
"model_type": "blt_local_encoder",
|
| 40 |
+
"vocab_size": 260,
|
| 41 |
+
"hidden_size": 1024,
|
| 42 |
+
"hidden_size_global": 2048,
|
| 43 |
+
"num_hidden_layers": 1,
|
| 44 |
+
"num_attention_heads": 16,
|
| 45 |
+
"num_key_value_heads": 16,
|
| 46 |
+
"head_dim": 64,
|
| 47 |
+
"intermediate_size": 2816,
|
| 48 |
+
"rms_norm_eps": 1e-5,
|
| 49 |
+
"dropout": 0.0,
|
| 50 |
+
"max_position_embeddings": 24576,
|
| 51 |
+
"cross_attn_all_layers": false,
|
| 52 |
+
"cross_attn_k": 2,
|
| 53 |
+
"hidden_act": "silu",
|
| 54 |
+
"initializer_range": 0.02,
|
| 55 |
+
"rope_parameters": {"rope_type": "default",
|
| 56 |
+
"rope_theta": 500000
|
| 57 |
+
}
|
| 58 |
+
},
|
| 59 |
+
|
| 60 |
+
"decoder_config": {
|
| 61 |
+
"model_type": "blt_local_decoder",
|
| 62 |
+
"vocab_size": 260,
|
| 63 |
+
"hidden_size": 1024,
|
| 64 |
+
"hidden_size_global": 2048,
|
| 65 |
+
"num_hidden_layers": 9,
|
| 66 |
+
"num_attention_heads": 16,
|
| 67 |
+
"num_key_value_heads": 16,
|
| 68 |
+
"head_dim": 64,
|
| 69 |
+
"intermediate_size": 2816,
|
| 70 |
+
"rms_norm_eps": 1e-5,
|
| 71 |
+
"dropout": 0.0,
|
| 72 |
+
"max_position_embeddings": 24576,
|
| 73 |
+
"cross_attn_all_layers": true,
|
| 74 |
+
"cross_attn_k": 2,
|
| 75 |
+
"hidden_act": "silu",
|
| 76 |
+
"initializer_range": 0.02,
|
| 77 |
+
"rope_parameters": {"rope_type": "default",
|
| 78 |
+
"rope_theta": 500000
|
| 79 |
+
}
|
| 80 |
+
},
|
| 81 |
+
"global_config": {
|
| 82 |
+
"model_type": "blt_global_transformer",
|
| 83 |
+
"hidden_size": 2048,
|
| 84 |
+
"num_hidden_layers": 25,
|
| 85 |
+
"num_attention_heads": 16,
|
| 86 |
+
"num_key_value_heads": 16,
|
| 87 |
+
"head_dim": 128,
|
| 88 |
+
"intermediate_size": 5632,
|
| 89 |
+
"rms_norm_eps": 1e-5,
|
| 90 |
+
"dropout": 0.0,
|
| 91 |
+
"max_position_embeddings": 4096,
|
| 92 |
+
"hidden_act": "silu",
|
| 93 |
+
"initializer_range": 0.02,
|
| 94 |
+
"rope_parameters": {"rope_type": "default",
|
| 95 |
+
"rope_theta": 500000
|
| 96 |
+
},
|
| 97 |
+
"encoder_cross_output_size": null
|
| 98 |
+
}
|
| 99 |
+
}
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/blt_transformer_380M.json
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "blt",
|
| 3 |
+
"vocab_size": 260,
|
| 4 |
+
"max_position_embeddings": 4096,
|
| 5 |
+
"initializer_range": 0.02,
|
| 6 |
+
"tie_word_embeddings": false,
|
| 7 |
+
"patch_in_forward": true,
|
| 8 |
+
"patch_size": 4,
|
| 9 |
+
"patching_mode": "entropy",
|
| 10 |
+
"patching_threshold": 1.335442066192627,
|
| 11 |
+
"patching_batch_size": 1,
|
| 12 |
+
"max_patch_length": null,
|
| 13 |
+
"patching_device": "cuda",
|
| 14 |
+
"realtime_patching": true,
|
| 15 |
+
"patching_threshold_add": null,
|
| 16 |
+
"monotonicity": false,
|
| 17 |
+
"cross_attn_k": 2,
|
| 18 |
+
"encoder_hash_byte_group_size": [3, 4, 5, 6, 7, 8],
|
| 19 |
+
"encoder_hash_byte_group_vocab": 500,
|
| 20 |
+
"encoder_hash_byte_group_nb_functions": 1,
|
| 21 |
+
"patcher_config": {
|
| 22 |
+
"model_type": "blt_patcher",
|
| 23 |
+
"vocab_size": 260,
|
| 24 |
+
"hidden_size": 512,
|
| 25 |
+
"num_hidden_layers": 7,
|
| 26 |
+
"num_attention_heads": 8,
|
| 27 |
+
"num_key_value_heads": 8,
|
| 28 |
+
"max_position_embeddings": 8192,
|
| 29 |
+
"rms_norm_eps": 1e-5,
|
| 30 |
+
"dropout": 0.0,
|
| 31 |
+
"intermediate_size": 1365,
|
| 32 |
+
"hidden_act": "silu",
|
| 33 |
+
"initializer_range": 0.02,
|
| 34 |
+
"rope_parameters": {"rope_type": "default",
|
| 35 |
+
"rope_theta": 500000
|
| 36 |
+
}
|
| 37 |
+
},
|
| 38 |
+
"encoder_config": {
|
| 39 |
+
"model_type": "blt_local_encoder",
|
| 40 |
+
"vocab_size": 260,
|
| 41 |
+
"hidden_size": 512,
|
| 42 |
+
"hidden_size_global": 1024,
|
| 43 |
+
"num_hidden_layers": 1,
|
| 44 |
+
"num_attention_heads": 8,
|
| 45 |
+
"num_key_value_heads": 8,
|
| 46 |
+
"head_dim": 64,
|
| 47 |
+
"intermediate_size": 1365,
|
| 48 |
+
"rms_norm_eps": 1e-5,
|
| 49 |
+
"dropout": 0.0,
|
| 50 |
+
"max_position_embeddings": 24576,
|
| 51 |
+
"cross_attn_all_layers": false,
|
| 52 |
+
"cross_attn_k": 2,
|
| 53 |
+
"hidden_act": "silu",
|
| 54 |
+
"initializer_range": 0.02,
|
| 55 |
+
"rope_parameters": {"rope_type": "default",
|
| 56 |
+
"rope_theta": 500000
|
| 57 |
+
}
|
| 58 |
+
},
|
| 59 |
+
"decoder_config": {
|
| 60 |
+
"model_type": "blt_local_decoder",
|
| 61 |
+
"vocab_size": 260,
|
| 62 |
+
"hidden_size": 512,
|
| 63 |
+
"hidden_size_global": 1024,
|
| 64 |
+
"num_hidden_layers": 9,
|
| 65 |
+
"num_attention_heads": 8,
|
| 66 |
+
"num_key_value_heads": 8,
|
| 67 |
+
"head_dim": 64,
|
| 68 |
+
"intermediate_size": 1365,
|
| 69 |
+
"rms_norm_eps": 1e-5,
|
| 70 |
+
"dropout": 0.0,
|
| 71 |
+
"max_position_embeddings": 24576,
|
| 72 |
+
"cross_attn_all_layers": true,
|
| 73 |
+
"cross_attn_k": 2,
|
| 74 |
+
"hidden_act": "silu",
|
| 75 |
+
"initializer_range": 0.02,
|
| 76 |
+
"rope_parameters": {"rope_type": "default",
|
| 77 |
+
"rope_theta": 500000
|
| 78 |
+
}
|
| 79 |
+
},
|
| 80 |
+
"global_config": {
|
| 81 |
+
"model_type": "blt_global_transformer",
|
| 82 |
+
"hidden_size": 1024,
|
| 83 |
+
"num_hidden_layers": 25,
|
| 84 |
+
"num_attention_heads": 8,
|
| 85 |
+
"num_key_value_heads": 8,
|
| 86 |
+
"head_dim": 128,
|
| 87 |
+
"intermediate_size": 2731,
|
| 88 |
+
"rms_norm_eps": 1e-5,
|
| 89 |
+
"dropout": 0.0,
|
| 90 |
+
"max_position_embeddings": 4096,
|
| 91 |
+
"hidden_act": "silu",
|
| 92 |
+
"initializer_range": 0.02,
|
| 93 |
+
"rope_parameters": {"rope_type": "default",
|
| 94 |
+
"rope_theta": 500000
|
| 95 |
+
},
|
| 96 |
+
"encoder_cross_output_size": null
|
| 97 |
+
}
|
| 98 |
+
}
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/delta_net_1B.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attn": null,
|
| 3 |
+
"attn_mode": "chunk",
|
| 4 |
+
"bos_token_id": 1,
|
| 5 |
+
"conv_size": 4,
|
| 6 |
+
"eos_token_id": 2,
|
| 7 |
+
"expand_k": 1,
|
| 8 |
+
"expand_v": 1,
|
| 9 |
+
"fuse_cross_entropy": true,
|
| 10 |
+
"fuse_norm": true,
|
| 11 |
+
"hidden_act": "swish",
|
| 12 |
+
"hidden_ratio": 4,
|
| 13 |
+
"hidden_size": 2048,
|
| 14 |
+
"initializer_range": 0.02,
|
| 15 |
+
"intermediate_size": null,
|
| 16 |
+
"model_type": "delta_net",
|
| 17 |
+
"norm_eps": 1e-06,
|
| 18 |
+
"num_heads": 16,
|
| 19 |
+
"num_hidden_layers": 24,
|
| 20 |
+
"pad_token_id": 2,
|
| 21 |
+
"qk_activation": "silu",
|
| 22 |
+
"qk_norm": "l2",
|
| 23 |
+
"tie_word_embeddings": false,
|
| 24 |
+
"use_beta": true,
|
| 25 |
+
"use_cache": true,
|
| 26 |
+
"use_gate": false,
|
| 27 |
+
"use_output_norm": true,
|
| 28 |
+
"use_short_conv": true
|
| 29 |
+
}
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/delta_net_340M.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attn_mode": "chunk",
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"conv_size": 4,
|
| 5 |
+
"eos_token_id": 2,
|
| 6 |
+
"expand_k": 1,
|
| 7 |
+
"expand_v": 1,
|
| 8 |
+
"fuse_cross_entropy": true,
|
| 9 |
+
"hidden_act": "swish",
|
| 10 |
+
"hidden_ratio": 4,
|
| 11 |
+
"hidden_size": 1024,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": null,
|
| 14 |
+
"model_type": "delta_net",
|
| 15 |
+
"norm_eps": 1e-06,
|
| 16 |
+
"num_heads": 8,
|
| 17 |
+
"num_hidden_layers": 24,
|
| 18 |
+
"qk_activation": "silu",
|
| 19 |
+
"qk_norm": "l2",
|
| 20 |
+
"tie_word_embeddings": false,
|
| 21 |
+
"use_beta": true,
|
| 22 |
+
"use_cache": true,
|
| 23 |
+
"use_gate": false,
|
| 24 |
+
"use_output_norm": true,
|
| 25 |
+
"use_short_conv": true
|
| 26 |
+
}
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/gated_deltanet_1B.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attn_mode": "chunk",
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"conv_size": 4,
|
| 5 |
+
"eos_token_id": 2,
|
| 6 |
+
"expand_v": 2,
|
| 7 |
+
"fuse_cross_entropy": true,
|
| 8 |
+
"head_dim": 256,
|
| 9 |
+
"hidden_act": "swish",
|
| 10 |
+
"hidden_ratio": 4,
|
| 11 |
+
"hidden_size": 2048,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": null,
|
| 14 |
+
"model_type": "gated_deltanet",
|
| 15 |
+
"norm_eps": 1e-06,
|
| 16 |
+
"num_heads": 6,
|
| 17 |
+
"num_hidden_layers": 21,
|
| 18 |
+
"tie_word_embeddings": false,
|
| 19 |
+
"use_cache": true,
|
| 20 |
+
"use_gate": true,
|
| 21 |
+
"use_short_conv": true
|
| 22 |
+
}
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/gated_deltanet_340M.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attn_mode": "chunk",
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"conv_size": 4,
|
| 5 |
+
"eos_token_id": 2,
|
| 6 |
+
"expand_v": 2,
|
| 7 |
+
"fuse_cross_entropy": true,
|
| 8 |
+
"head_dim": 256,
|
| 9 |
+
"hidden_act": "swish",
|
| 10 |
+
"hidden_ratio": 4,
|
| 11 |
+
"hidden_size": 1024,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": null,
|
| 14 |
+
"model_type": "gated_deltanet",
|
| 15 |
+
"norm_eps": 1e-06,
|
| 16 |
+
"num_heads": 6,
|
| 17 |
+
"num_hidden_layers": 21,
|
| 18 |
+
"tie_word_embeddings": false,
|
| 19 |
+
"use_cache": true,
|
| 20 |
+
"use_gate": true,
|
| 21 |
+
"use_short_conv": true
|
| 22 |
+
}
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/gated_deltanet_h_340M.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "gated_deltanet",
|
| 3 |
+
"attn_mode": "chunk",
|
| 4 |
+
"hidden_size": 1024,
|
| 5 |
+
"num_hidden_layers": 21,
|
| 6 |
+
"head_dim": 256,
|
| 7 |
+
"num_heads": 6,
|
| 8 |
+
"expand_v": 2,
|
| 9 |
+
"hidden_ratio": 4,
|
| 10 |
+
"use_gate": true,
|
| 11 |
+
"use_short_conv": true,
|
| 12 |
+
"conv_size": 4,
|
| 13 |
+
"vocab_size": 32000,
|
| 14 |
+
"hidden_act": "swish",
|
| 15 |
+
"norm_eps": 1e-06,
|
| 16 |
+
"bos_token_id": 1,
|
| 17 |
+
"eos_token_id": 2,
|
| 18 |
+
"fuse_cross_entropy": true,
|
| 19 |
+
"initializer_range": 0.02,
|
| 20 |
+
"attn": {
|
| 21 |
+
"layers": [3, 7, 11, 15, 19],
|
| 22 |
+
"num_heads": 8,
|
| 23 |
+
"num_kv_heads": 1,
|
| 24 |
+
"window_size": 2048,
|
| 25 |
+
"rope_theta": 100000.0,
|
| 26 |
+
"qkv_bias": false
|
| 27 |
+
}
|
| 28 |
+
}
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/gla_340M.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attn_mode": "chunk",
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"clamp_min": null,
|
| 5 |
+
"eos_token_id": 2,
|
| 6 |
+
"expand_k": 0.5,
|
| 7 |
+
"expand_v": 1,
|
| 8 |
+
"fuse_cross_entropy": true,
|
| 9 |
+
"fuse_norm": true,
|
| 10 |
+
"hidden_act": "swish",
|
| 11 |
+
"hidden_ratio": 4,
|
| 12 |
+
"hidden_size": 1024,
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"intermediate_size": null,
|
| 15 |
+
"model_type": "gla",
|
| 16 |
+
"num_heads": 4,
|
| 17 |
+
"num_hidden_layers": 24,
|
| 18 |
+
"norm_eps": 1e-06,
|
| 19 |
+
"tie_word_embeddings": false,
|
| 20 |
+
"use_cache": true,
|
| 21 |
+
"use_gk": true,
|
| 22 |
+
"use_gv": false,
|
| 23 |
+
"vocab_size": 32000
|
| 24 |
+
}
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/gla_7B.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attn": null,
|
| 3 |
+
"attn_mode": "chunk",
|
| 4 |
+
"bos_token_id": 1,
|
| 5 |
+
"eos_token_id": 2,
|
| 6 |
+
"expand_k": 0.5,
|
| 7 |
+
"expand_v": 1,
|
| 8 |
+
"fuse_cross_entropy": true,
|
| 9 |
+
"fuse_norm": true,
|
| 10 |
+
"hidden_act": "swish",
|
| 11 |
+
"hidden_ratio": 4,
|
| 12 |
+
"hidden_size": 4096,
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"intermediate_size": 11008,
|
| 15 |
+
"model_type": "gla",
|
| 16 |
+
"norm_eps": 1e-06,
|
| 17 |
+
"num_heads": 16,
|
| 18 |
+
"num_hidden_layers": 32,
|
| 19 |
+
"tie_word_embeddings": false,
|
| 20 |
+
"use_cache": true,
|
| 21 |
+
"use_gk": true,
|
| 22 |
+
"use_gv": false,
|
| 23 |
+
"use_output_gate": true,
|
| 24 |
+
"use_short_conv": false
|
| 25 |
+
}
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/gsa_340M.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 1,
|
| 3 |
+
"conv_size": 4,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"expand_k": 1,
|
| 6 |
+
"expand_v": 1,
|
| 7 |
+
"elementwise_affine": false,
|
| 8 |
+
"feature_map": "swish",
|
| 9 |
+
"fuse_cross_entropy": true,
|
| 10 |
+
"fuse_norm": true,
|
| 11 |
+
"gate_logit_normalizer": 4,
|
| 12 |
+
"hidden_act": "swish",
|
| 13 |
+
"hidden_ratio": 4,
|
| 14 |
+
"hidden_size": 1024,
|
| 15 |
+
"initializer_range": 0.02,
|
| 16 |
+
"intermediate_size": null,
|
| 17 |
+
"model_type": "gsa",
|
| 18 |
+
"num_heads": 4,
|
| 19 |
+
"num_hidden_layers": 24,
|
| 20 |
+
"num_slots": 64,
|
| 21 |
+
"norm_eps": 1e-06,
|
| 22 |
+
"share_conv_kernel": true,
|
| 23 |
+
"tie_word_embeddings": false,
|
| 24 |
+
"use_cache": true,
|
| 25 |
+
"use_norm": true,
|
| 26 |
+
"use_output_gate": true,
|
| 27 |
+
"use_rope": false,
|
| 28 |
+
"use_short_conv": false
|
| 29 |
+
}
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/mergenet_340M.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "mergenet",
|
| 3 |
+
"vocab_size": 260,
|
| 4 |
+
"hidden_size": 1024,
|
| 5 |
+
"num_local_layers": 6,
|
| 6 |
+
"local_depth": 4,
|
| 7 |
+
"num_latent_layers": 12,
|
| 8 |
+
"num_heads": 16,
|
| 9 |
+
"num_kv_heads": 16,
|
| 10 |
+
"intermediate_size": 4096,
|
| 11 |
+
"hidden_act": "swish",
|
| 12 |
+
"max_position_embeddings": 8192,
|
| 13 |
+
"lambda_local": 4.0,
|
| 14 |
+
"dtem_window_size": 8,
|
| 15 |
+
"dtem_t": 1,
|
| 16 |
+
"dtem_feat_dim": null,
|
| 17 |
+
"use_softkmax": false,
|
| 18 |
+
"grid_bias_gamma": 1.0,
|
| 19 |
+
"W_infer": null,
|
| 20 |
+
"qkv_bias": true,
|
| 21 |
+
"qk_norm": false,
|
| 22 |
+
"rope_theta": 10000.0,
|
| 23 |
+
"norm_eps": 1e-6,
|
| 24 |
+
"initializer_range": 0.02,
|
| 25 |
+
"use_cache": true,
|
| 26 |
+
"pad_token_id": 0,
|
| 27 |
+
"bos_token_id": 1,
|
| 28 |
+
"eos_token_id": 2,
|
| 29 |
+
"tie_word_embeddings": false,
|
| 30 |
+
"phase": "phase2",
|
| 31 |
+
"drop_rate": 0.0,
|
| 32 |
+
"attn_drop_rate": 0.0,
|
| 33 |
+
"drop_path_rate": 0.1
|
| 34 |
+
}
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/mergenet_64M.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "mergenet",
|
| 3 |
+
"vocab_size": 32000,
|
| 4 |
+
"hidden_size": 512,
|
| 5 |
+
"num_local_layers": 4,
|
| 6 |
+
"local_depth": 4,
|
| 7 |
+
"num_latent_layers": 8,
|
| 8 |
+
"num_heads": 8,
|
| 9 |
+
"num_kv_heads": 8,
|
| 10 |
+
"intermediate_size": 2048,
|
| 11 |
+
"hidden_act": "swish",
|
| 12 |
+
"max_position_embeddings": 4096,
|
| 13 |
+
"lambda_local": 4.0,
|
| 14 |
+
"dtem_window_size": 8,
|
| 15 |
+
"dtem_t": 1,
|
| 16 |
+
"dtem_feat_dim": null,
|
| 17 |
+
"use_softkmax": false,
|
| 18 |
+
"grid_bias_gamma": 1.0,
|
| 19 |
+
"W_infer": null,
|
| 20 |
+
"qkv_bias": true,
|
| 21 |
+
"qk_norm": false,
|
| 22 |
+
"rope_theta": 10000.0,
|
| 23 |
+
"norm_eps": 1e-6,
|
| 24 |
+
"initializer_range": 0.02,
|
| 25 |
+
"use_cache": true,
|
| 26 |
+
"pad_token_id": 0,
|
| 27 |
+
"bos_token_id": 1,
|
| 28 |
+
"eos_token_id": 2,
|
| 29 |
+
"tie_word_embeddings": false,
|
| 30 |
+
"phase": "phase2",
|
| 31 |
+
"drop_rate": 0.0,
|
| 32 |
+
"attn_drop_rate": 0.0,
|
| 33 |
+
"drop_path_rate": 0.1
|
| 34 |
+
}
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/qwen3_next_1B.json
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "qwen3_next",
|
| 3 |
+
"vocab_size": 151936,
|
| 4 |
+
"hidden_size": 2048,
|
| 5 |
+
"intermediate_size": 5632,
|
| 6 |
+
"num_hidden_layers": 48,
|
| 7 |
+
"num_attention_heads": 16,
|
| 8 |
+
"num_key_value_heads": 2,
|
| 9 |
+
"head_dim": 256,
|
| 10 |
+
"hidden_act": "silu",
|
| 11 |
+
"max_position_embeddings": 32768,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"rms_norm_eps": 1e-6,
|
| 14 |
+
"use_cache": true,
|
| 15 |
+
"tie_word_embeddings": false,
|
| 16 |
+
"attention_bias": false,
|
| 17 |
+
"attention_dropout": 0.0,
|
| 18 |
+
"rope_parameters": {
|
| 19 |
+
"rope_type": "default",
|
| 20 |
+
"factor": 1.0
|
| 21 |
+
},
|
| 22 |
+
"partial_rotary_factor": 0.25,
|
| 23 |
+
"layer_types": [
|
| 24 |
+
"linear_attention",
|
| 25 |
+
"linear_attention",
|
| 26 |
+
"linear_attention",
|
| 27 |
+
"full_attention"
|
| 28 |
+
],
|
| 29 |
+
"linear_conv_kernel_dim": 4,
|
| 30 |
+
"linear_key_head_dim": 128,
|
| 31 |
+
"linear_value_head_dim": 128,
|
| 32 |
+
"linear_num_key_heads": 16,
|
| 33 |
+
"linear_num_value_heads": 32,
|
| 34 |
+
"decoder_sparse_step": 1,
|
| 35 |
+
"moe_intermediate_size": 512,
|
| 36 |
+
"shared_expert_intermediate_size": 512,
|
| 37 |
+
"num_experts_per_tok": 10,
|
| 38 |
+
"num_experts": 512,
|
| 39 |
+
"norm_topk_prob": true,
|
| 40 |
+
"output_router_logits": false,
|
| 41 |
+
"router_aux_loss_coef": 0.001,
|
| 42 |
+
"mlp_only_layers": []
|
| 43 |
+
}
|
| 44 |
+
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/qwen3_next_350M.json
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "qwen3_next",
|
| 3 |
+
"vocab_size": 32000,
|
| 4 |
+
"hidden_size": 2048,
|
| 5 |
+
"intermediate_size": 5632,
|
| 6 |
+
"num_hidden_layers": 26,
|
| 7 |
+
"num_attention_heads": 16,
|
| 8 |
+
"num_key_value_heads": 2,
|
| 9 |
+
"head_dim": 256,
|
| 10 |
+
"hidden_act": "silu",
|
| 11 |
+
"max_position_embeddings": 32768,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"rms_norm_eps": 1e-6,
|
| 14 |
+
"use_cache": true,
|
| 15 |
+
"tie_word_embeddings": false,
|
| 16 |
+
"attention_bias": false,
|
| 17 |
+
"attention_dropout": 0.0,
|
| 18 |
+
"rope_parameters": {
|
| 19 |
+
"rope_type": "default",
|
| 20 |
+
"factor": 1.0
|
| 21 |
+
},
|
| 22 |
+
"partial_rotary_factor": 0.25,
|
| 23 |
+
"layer_types": [
|
| 24 |
+
"linear_attention",
|
| 25 |
+
"linear_attention",
|
| 26 |
+
"linear_attention",
|
| 27 |
+
"full_attention"
|
| 28 |
+
],
|
| 29 |
+
"linear_conv_kernel_dim": 4,
|
| 30 |
+
"linear_key_head_dim": 128,
|
| 31 |
+
"linear_value_head_dim": 128,
|
| 32 |
+
"linear_num_key_heads": 16,
|
| 33 |
+
"linear_num_value_heads": 32,
|
| 34 |
+
"decoder_sparse_step": 1,
|
| 35 |
+
"moe_intermediate_size": 512,
|
| 36 |
+
"shared_expert_intermediate_size": 512,
|
| 37 |
+
"num_experts_per_tok": 10,
|
| 38 |
+
"num_experts": 512,
|
| 39 |
+
"norm_topk_prob": true,
|
| 40 |
+
"output_router_logits": false,
|
| 41 |
+
"router_aux_loss_coef": 0.001,
|
| 42 |
+
"mlp_only_layers": []
|
| 43 |
+
}
|
| 44 |
+
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/transformer_1B.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 1,
|
| 3 |
+
"elementwise_affine": true,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"fuse_cross_entropy": true,
|
| 6 |
+
"fuse_norm": true,
|
| 7 |
+
"fuse_swiglu": true,
|
| 8 |
+
"hidden_act": "swish",
|
| 9 |
+
"hidden_ratio": 4,
|
| 10 |
+
"hidden_size": 2048,
|
| 11 |
+
"initializer_range": 0.02,
|
| 12 |
+
"intermediate_size": null,
|
| 13 |
+
"max_position_embeddings": 8192,
|
| 14 |
+
"model_type": "transformer",
|
| 15 |
+
"norm_eps": 1e-06,
|
| 16 |
+
"num_heads": 32,
|
| 17 |
+
"num_hidden_layers": 24,
|
| 18 |
+
"num_kv_heads": null,
|
| 19 |
+
"pad_token_id": 2,
|
| 20 |
+
"rope_theta": 10000.0,
|
| 21 |
+
"tie_word_embeddings": false
|
| 22 |
+
}
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/transformer_340M.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attention_bias": false,
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"fuse_cross_entropy": true,
|
| 6 |
+
"fuse_norm": true,
|
| 7 |
+
"hidden_act": "swish",
|
| 8 |
+
"hidden_size": 1024,
|
| 9 |
+
"initializer_range": 0.02,
|
| 10 |
+
"max_position_embeddings": 8192,
|
| 11 |
+
"model_type": "transformer",
|
| 12 |
+
"num_heads": 16,
|
| 13 |
+
"num_hidden_layers": 24,
|
| 14 |
+
"norm_eps": 1e-06,
|
| 15 |
+
"tie_word_embeddings": false,
|
| 16 |
+
"use_cache": true,
|
| 17 |
+
"vocab_size": 32000
|
| 18 |
+
}
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/configs/transformer_7B.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attention_bias": false,
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"fuse_cross_entropy": true,
|
| 6 |
+
"fuse_norm": true,
|
| 7 |
+
"hidden_act": "swish",
|
| 8 |
+
"hidden_ratio": 4,
|
| 9 |
+
"hidden_size": 4096,
|
| 10 |
+
"initializer_range": 0.02,
|
| 11 |
+
"intermediate_size": 14336,
|
| 12 |
+
"model_type": "transformer",
|
| 13 |
+
"norm_eps": 1e-06,
|
| 14 |
+
"num_heads": 32,
|
| 15 |
+
"num_hidden_layers": 32,
|
| 16 |
+
"num_kv_heads": 8,
|
| 17 |
+
"rope_theta": 10000.0,
|
| 18 |
+
"tie_word_embeddings": false,
|
| 19 |
+
"use_cache": true,
|
| 20 |
+
"window_size": null
|
| 21 |
+
}
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__version__ = "0.1.0"
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (207 Bytes). View file
|
|
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (244 Bytes). View file
|
|
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (238 Bytes). View file
|
|
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/config_manager.cpython-310.pyc
ADDED
|
Binary file (29.6 kB). View file
|
|
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/config_manager.cpython-311.pyc
ADDED
|
Binary file (41.5 kB). View file
|
|
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/data.cpython-310.pyc
ADDED
|
Binary file (21.7 kB). View file
|
|
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/data.cpython-311.pyc
ADDED
|
Binary file (41.6 kB). View file
|
|
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/train.cpython-310.pyc
ADDED
|
Binary file (19.6 kB). View file
|
|
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/train.cpython-311.pyc
ADDED
|
Binary file (41.2 kB). View file
|
|
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/__pycache__/train.cpython-313.pyc
ADDED
|
Binary file (39.7 kB). View file
|
|
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/c4_test.py
ADDED
|
@@ -0,0 +1,603 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
import json
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
from loguru import logger
|
| 16 |
+
|
| 17 |
+
import transformers
|
| 18 |
+
|
| 19 |
+
transformers.logging.set_verbosity_error()
|
| 20 |
+
|
| 21 |
+
import wandb
|
| 22 |
+
|
| 23 |
+
from utils.argparse import parse_args
|
| 24 |
+
from utils.setup import getting_svd_cnt, set_seed, setup_model, saving_model_weight, load_model_weight
|
| 25 |
+
from utils.optimizer_factory import setup_optimization
|
| 26 |
+
from utils.eval import evaluate_model
|
| 27 |
+
from utils.dataloader import setup_dataset
|
| 28 |
+
from utils.modeling_llama import LlamaForCausalLM
|
| 29 |
+
from utils.fake_quantization import QLinear
|
| 30 |
+
from utils.quantization import QScaleLinear
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def main(args):
|
| 34 |
+
import torch
|
| 35 |
+
############ Setup random seed ############
|
| 36 |
+
set_seed(args)
|
| 37 |
+
|
| 38 |
+
############ Setup DDP environment ############
|
| 39 |
+
assert "LOCAL_RANK" in os.environ, "torchrun should set LOCAL_RANK"
|
| 40 |
+
global_rank = int(os.environ["RANK"])
|
| 41 |
+
local_rank = int(os.environ["LOCAL_RANK"])
|
| 42 |
+
world_size = int(os.environ["WORLD_SIZE"])
|
| 43 |
+
torch.cuda.set_device(local_rank)
|
| 44 |
+
|
| 45 |
+
logger.info(f"Global rank {global_rank}, local rank {local_rank}, device: {torch.cuda.current_device()}")
|
| 46 |
+
dist.init_process_group(backend="nccl", rank=global_rank, world_size=world_size)
|
| 47 |
+
|
| 48 |
+
logger.info("Process group initialized")
|
| 49 |
+
device = f"cuda:{local_rank}"
|
| 50 |
+
|
| 51 |
+
if global_rank != 0:
|
| 52 |
+
logger.remove() # turn off logger
|
| 53 |
+
|
| 54 |
+
logger.info(f"Using dist with rank {global_rank} (only rank 0 will log)")
|
| 55 |
+
logger.info("*" * 40)
|
| 56 |
+
logger.info(f"Starting training with the arguments")
|
| 57 |
+
for k, v in vars(args).items():
|
| 58 |
+
logger.info(f"{k:30} {v}")
|
| 59 |
+
logger.info("*" * 40)
|
| 60 |
+
|
| 61 |
+
############ Initialize wandb without config (it is passed later) ############
|
| 62 |
+
if (not args.unset_wandb) and global_rank == 0:
|
| 63 |
+
if args.entity is None:
|
| 64 |
+
os.environ['WANDB_MODE'] = 'offline'
|
| 65 |
+
# Set wandb directory for offline mode
|
| 66 |
+
wandb_dir = getattr(args, 'wandb_dir', None) if getattr(args, 'wandb_dir', None) is not None else args.save_dir
|
| 67 |
+
if getattr(args, 'wandb_dir', None) is not None:
|
| 68 |
+
logger.info(f"Wandb directory set to: {wandb_dir}")
|
| 69 |
+
wandb.init(project=args.project, name=args.name, entity=args.entity, dir=wandb_dir)
|
| 70 |
+
|
| 71 |
+
############ Setup training data ############
|
| 72 |
+
if args.total_batch_size is not None:
|
| 73 |
+
if args.gradient_accumulation is None:
|
| 74 |
+
assert args.total_batch_size % world_size == 0, "total_batch_size must be divisible by world_size"
|
| 75 |
+
args.gradient_accumulation = args.total_batch_size // (args.batch_size * world_size)
|
| 76 |
+
assert args.gradient_accumulation > 0, "gradient_accumulation must be greater than 0"
|
| 77 |
+
|
| 78 |
+
assert (
|
| 79 |
+
args.gradient_accumulation * args.batch_size * world_size == args.total_batch_size
|
| 80 |
+
), "gradient_accumulation * batch_size * world_size must be equal to total_batch_size"
|
| 81 |
+
|
| 82 |
+
dataloader, tokenizer = setup_dataset(args, global_rank, world_size)
|
| 83 |
+
|
| 84 |
+
############ Initialize model ############
|
| 85 |
+
model_config, model = setup_model(args)
|
| 86 |
+
# Ensure model has generation_config (fix for transformers version compatibility)
|
| 87 |
+
if model.generation_config is None:
|
| 88 |
+
from transformers import GenerationConfig
|
| 89 |
+
model.generation_config = GenerationConfig()
|
| 90 |
+
model.generation_config.pad_token_id = tokenizer.pad_token_id
|
| 91 |
+
|
| 92 |
+
############ Resuming from checkpoints ############
|
| 93 |
+
global_step = 0
|
| 94 |
+
update_step = 0
|
| 95 |
+
beginning_step = 0
|
| 96 |
+
tokens_seen = 0
|
| 97 |
+
tokens_seen_before = 0
|
| 98 |
+
|
| 99 |
+
# identifying checkpointing
|
| 100 |
+
if args.continue_from is not None and os.path.exists(args.continue_from):
|
| 101 |
+
# searching the latest checkpoints
|
| 102 |
+
checkpoint_path_list = os.listdir(args.continue_from)
|
| 103 |
+
checkpoint_path_list = [int(x.split("_")[-1]) for x in checkpoint_path_list if x.startswith("model_")]
|
| 104 |
+
if len(checkpoint_path_list) > 0:
|
| 105 |
+
logger.info("Find Checkpoints", checkpoint_path_list)
|
| 106 |
+
beginning_step = max(checkpoint_path_list)
|
| 107 |
+
if args.resume_step is not None:
|
| 108 |
+
beginning_step = args.resume_step
|
| 109 |
+
args.continue_from = os.path.join(args.continue_from, f"model_{beginning_step}")
|
| 110 |
+
logger.info("Continue from", args.continue_from)
|
| 111 |
+
else:
|
| 112 |
+
logger.warning(f"Did not find any checkpoints in {args.continue_from}")
|
| 113 |
+
args.continue_from = None
|
| 114 |
+
|
| 115 |
+
# resuming from checkpointing
|
| 116 |
+
if args.continue_from is not None:
|
| 117 |
+
logger.info("*" * 40)
|
| 118 |
+
logger.info(f"Loading model from {args.continue_from}")
|
| 119 |
+
checkpoint_path = os.path.join(args.continue_from, "pytorch_model.bin")
|
| 120 |
+
if os.path.exists(checkpoint_path):
|
| 121 |
+
load_model_weight(model, checkpoint_path, args)
|
| 122 |
+
logger.info(f"Model successfully loaded (strict=False policy)")
|
| 123 |
+
else:
|
| 124 |
+
# Try safetensors format
|
| 125 |
+
checkpoint_path = os.path.join(args.continue_from, "model.safetensors")
|
| 126 |
+
if os.path.exists(checkpoint_path):
|
| 127 |
+
from safetensors import safe_open
|
| 128 |
+
tensors = {}
|
| 129 |
+
with safe_open(checkpoint_path, framework="pt", device=0) as f:
|
| 130 |
+
for k in f.keys():
|
| 131 |
+
tensors[k] = f.get_tensor(k)
|
| 132 |
+
print(k, tensors[k].shape)
|
| 133 |
+
ret = model.load_state_dict(tensors, strict=False)
|
| 134 |
+
logger.info(f"Model successfully loaded from safetensors (strict=False policy)", ret)
|
| 135 |
+
else:
|
| 136 |
+
logger.warning(f"No model checkpoint found in {args.continue_from}")
|
| 137 |
+
|
| 138 |
+
if os.path.exists(os.path.join(args.continue_from, "training_state.json")):
|
| 139 |
+
logger.info(
|
| 140 |
+
f"Loading training state like global_step, update_step, and tokens_seen from {args.continue_from}"
|
| 141 |
+
)
|
| 142 |
+
with open(os.path.join(args.continue_from, "training_state.json")) as f:
|
| 143 |
+
_old_state = json.load(f)
|
| 144 |
+
global_step = _old_state["global_step"]
|
| 145 |
+
update_step = _old_state["update_step"]
|
| 146 |
+
tokens_seen = _old_state["tokens_seen"]
|
| 147 |
+
tokens_seen_before = _old_state["tokens_seen_before"]
|
| 148 |
+
logger.info(f"global_step : {global_step}")
|
| 149 |
+
logger.info(f"update_step : {update_step}")
|
| 150 |
+
logger.info(f"tokens_seen : {tokens_seen}")
|
| 151 |
+
logger.info(f"tokens_seen_before: {tokens_seen_before}")
|
| 152 |
+
logger.info(f"Will train for {args.num_training_steps - update_step} update steps")
|
| 153 |
+
else:
|
| 154 |
+
logger.warning(f"Did not find training state in {args.continue_from}, global step will start from zero")
|
| 155 |
+
logger.info("*" * 40)
|
| 156 |
+
|
| 157 |
+
############ Setup model ############
|
| 158 |
+
if args.dtype in ["bf16", "bfloat16"]:
|
| 159 |
+
model = model.to(dtype=torch.bfloat16)
|
| 160 |
+
model = model.to(device=device)
|
| 161 |
+
|
| 162 |
+
for _, module in model.named_modules():
|
| 163 |
+
if isinstance(module, QScaleLinear):
|
| 164 |
+
weight_device = module.weight.device
|
| 165 |
+
module.weight.scales = module.weight.scales.to(device=weight_device)
|
| 166 |
+
module.weight.zeros = module.weight.zeros.to(device=weight_device)
|
| 167 |
+
|
| 168 |
+
n_total_params = sum(p.numel() for p in model.parameters())
|
| 169 |
+
trainable_params = [p for p in model.parameters() if p.requires_grad]
|
| 170 |
+
trainable_params_int8 = [p for p in model.parameters() if hasattr(p, "group_size")]
|
| 171 |
+
|
| 172 |
+
############ Initialize wandb ############
|
| 173 |
+
run_config = dict(vars(args))
|
| 174 |
+
run_config.update(
|
| 175 |
+
{
|
| 176 |
+
"max_lr": run_config.pop("lr"), # rename lr to max_lr to avoid conflicts with scheduler
|
| 177 |
+
"total_params_M": n_total_params / 1_000_000,
|
| 178 |
+
"dataset": "c4",
|
| 179 |
+
"model": model_config.to_dict(),
|
| 180 |
+
"world_size": world_size,
|
| 181 |
+
"device": str(device),
|
| 182 |
+
}
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
if global_rank == 0:
|
| 186 |
+
if not args.unset_wandb:
|
| 187 |
+
wandb.config.update(run_config, allow_val_change=True)
|
| 188 |
+
wandb.save(os.path.abspath(__file__), policy="now") # save current script
|
| 189 |
+
# fix tqdm visual length to 80 so that the progress bar
|
| 190 |
+
# doesn't jump around when changing from external display to laptop
|
| 191 |
+
pbar = tqdm(total=args.num_training_steps - update_step, desc="Update steps", ncols=80)
|
| 192 |
+
|
| 193 |
+
############ Initialize optimization ############
|
| 194 |
+
if "galore" in args.optimizer.lower():
|
| 195 |
+
# make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
|
| 196 |
+
lowrank_params = []
|
| 197 |
+
target_modules_list = ["attn", "mlp"]
|
| 198 |
+
for module_name, module in model.named_modules():
|
| 199 |
+
if not (isinstance(module, nn.Linear) or isinstance(module, QScaleLinear) or isinstance(module, QLinear)):
|
| 200 |
+
continue
|
| 201 |
+
if not any(target_key in module_name for target_key in target_modules_list):
|
| 202 |
+
continue
|
| 203 |
+
logger.info(f"Adding {module_name} to GaLore parameters")
|
| 204 |
+
lowrank_params.append(module.weight)
|
| 205 |
+
|
| 206 |
+
id_lowrank_params = [id(p) for p in lowrank_params]
|
| 207 |
+
# make parameters without "rank" to another group
|
| 208 |
+
regular_params = [p for p in model.parameters() if id(p) not in id_lowrank_params]
|
| 209 |
+
# then call low rank optimizer
|
| 210 |
+
param_groups = [
|
| 211 |
+
{"params": regular_params},
|
| 212 |
+
{
|
| 213 |
+
"params": lowrank_params,
|
| 214 |
+
"rank": args.rank,
|
| 215 |
+
"update_proj_gap": args.update_proj_gap,
|
| 216 |
+
"scale": args.galore_scale,
|
| 217 |
+
"proj_type": args.proj_type,
|
| 218 |
+
"quant": args.proj_quant,
|
| 219 |
+
"quant_n_bit": args.proj_bits,
|
| 220 |
+
"quant_group_size": args.proj_group_size,
|
| 221 |
+
"cos_threshold": args.cos_threshold,
|
| 222 |
+
"gamma_proj": args.gamma_proj,
|
| 223 |
+
"queue_size": args.queue_size,
|
| 224 |
+
},
|
| 225 |
+
]
|
| 226 |
+
elif "apollo" in args.optimizer.lower():
|
| 227 |
+
# make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
|
| 228 |
+
lowrank_params = []
|
| 229 |
+
target_modules_list = ["attn", "mlp"]
|
| 230 |
+
for module_name, module in model.named_modules():
|
| 231 |
+
if not (isinstance(module, nn.Linear) or isinstance(module, QScaleLinear) or isinstance(module, QLinear)):
|
| 232 |
+
continue
|
| 233 |
+
if not any(target_key in module_name for target_key in target_modules_list):
|
| 234 |
+
continue
|
| 235 |
+
logger.info(f"Adding {module_name} to APOLLO parameters")
|
| 236 |
+
lowrank_params.append(module.weight)
|
| 237 |
+
|
| 238 |
+
id_lowrank_params = [id(p) for p in lowrank_params]
|
| 239 |
+
# make parameters without "rank" to another group
|
| 240 |
+
regular_params = [p for p in model.parameters() if id(p) not in id_lowrank_params]
|
| 241 |
+
# then call low rank optimizer
|
| 242 |
+
param_groups = [
|
| 243 |
+
{"params": regular_params},
|
| 244 |
+
{
|
| 245 |
+
"params": lowrank_params,
|
| 246 |
+
"rank": args.rank,
|
| 247 |
+
"update_proj_gap": args.update_proj_gap,
|
| 248 |
+
"scale": args.apollo_scale,
|
| 249 |
+
"proj_type": args.proj_type,
|
| 250 |
+
"proj": args.proj,
|
| 251 |
+
"scale_type": args.scale_type,
|
| 252 |
+
},
|
| 253 |
+
]
|
| 254 |
+
elif "conda" in args.optimizer.lower():
|
| 255 |
+
# make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
|
| 256 |
+
lowrank_params = []
|
| 257 |
+
target_modules_list = ["attn", "mlp"]
|
| 258 |
+
for module_name, module in model.named_modules():
|
| 259 |
+
if not (isinstance(module, nn.Linear) or isinstance(module, QScaleLinear) or isinstance(module, QLinear)):
|
| 260 |
+
continue
|
| 261 |
+
if not any(target_key in module_name for target_key in target_modules_list):
|
| 262 |
+
continue
|
| 263 |
+
logger.info(f"Adding {module_name} to conda parameters")
|
| 264 |
+
lowrank_params.append(module.weight)
|
| 265 |
+
|
| 266 |
+
id_lowrank_params = [id(p) for p in lowrank_params]
|
| 267 |
+
# make parameters without "rank" to another group
|
| 268 |
+
regular_params = [p for p in model.parameters() if id(p) not in id_lowrank_params]
|
| 269 |
+
# then call low rank optimizer
|
| 270 |
+
param_groups = [
|
| 271 |
+
{"params": regular_params},
|
| 272 |
+
{
|
| 273 |
+
"params": lowrank_params,
|
| 274 |
+
"rank": args.rank,
|
| 275 |
+
"update_proj_gap": args.update_proj_gap,
|
| 276 |
+
"scale": args.apollo_scale,
|
| 277 |
+
"proj_type": args.proj_type,
|
| 278 |
+
"proj": args.proj,
|
| 279 |
+
"scale_type": args.scale_type,
|
| 280 |
+
},
|
| 281 |
+
]
|
| 282 |
+
else:
|
| 283 |
+
param_groups = None
|
| 284 |
+
id_lowrank_params = None
|
| 285 |
+
|
| 286 |
+
# print params and trainable params
|
| 287 |
+
logger.info(f"\n{model}\n")
|
| 288 |
+
logger.info(f"Total params: {sum(p.numel() for p in model.parameters()) / 1_000_000:.2f}M")
|
| 289 |
+
|
| 290 |
+
if args.simulation:
|
| 291 |
+
num_train_params = sum(p.numel() for p in trainable_params)
|
| 292 |
+
else:
|
| 293 |
+
num_train_params = sum(p.numel() for p in trainable_params) + sum(p.numel() for p in trainable_params_int8)
|
| 294 |
+
|
| 295 |
+
logger.info(f"Trainable params: {num_train_params / 1_000_000:.2f}M")
|
| 296 |
+
if "q_galore" in args.optimizer.lower():
|
| 297 |
+
logger.info(
|
| 298 |
+
f"Trainable params with Q-GaLore enabled: {sum(p.numel() for p in trainable_params_int8) / 1_000_000:.2f}M"
|
| 299 |
+
)
|
| 300 |
+
elif "galore" in args.optimizer.lower():
|
| 301 |
+
logger.info(f"Total params with GaLore enabled: {sum(p.numel() for p in lowrank_params) / 1_000_000:.2f}M")
|
| 302 |
+
elif "q_apollo" in args.optimizer.lower():
|
| 303 |
+
logger.info(
|
| 304 |
+
f"Trainable params with Q-APOLLO enabled: {sum(p.numel() for p in trainable_params_int8) / 1_000_000:.2f}M"
|
| 305 |
+
)
|
| 306 |
+
elif "apollo" in args.optimizer.lower():
|
| 307 |
+
logger.info(f"Total params with APOLLO enabled: {sum(p.numel() for p in lowrank_params) / 1_000_000:.2f}M")
|
| 308 |
+
|
| 309 |
+
logger.info(f"Saving model to {args.save_dir} every {args.save_every} update steps")
|
| 310 |
+
|
| 311 |
+
model, optimizer, scheduler, layer_wise_flag = setup_optimization(
|
| 312 |
+
args, model, trainable_params, param_groups, id_lowrank_params, model_config
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
if layer_wise_flag:
|
| 316 |
+
# will pass optimizer_dict and scheduler_dict out instead of optimizer and scheduler
|
| 317 |
+
optimizer_dict = optimizer
|
| 318 |
+
scheduler_dict = scheduler
|
| 319 |
+
|
| 320 |
+
# Bug-3 fix: wrap with DDP *before* torch.compile per PyTorch recommendation.
|
| 321 |
+
# This ensures gradient reduction hooks are correctly installed on the DDP module,
|
| 322 |
+
# and the compiled graph captures the full DDP+model forward pass.
|
| 323 |
+
# (Issue-5: optimizer.load_state_dict is called after both DDP and compile below.)
|
| 324 |
+
if not args.single_gpu:
|
| 325 |
+
model: LlamaForCausalLM = torch.nn.parallel.DistributedDataParallel(
|
| 326 |
+
model,
|
| 327 |
+
device_ids=[local_rank],
|
| 328 |
+
output_device=local_rank,
|
| 329 |
+
broadcast_buffers=False,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
# compile the model (after DDP so the compiled graph includes DDP reduction)
|
| 333 |
+
if args.compile:
|
| 334 |
+
print("Compiling the model... (takes a ~minute)")
|
| 335 |
+
unoptimized_model = model
|
| 336 |
+
|
| 337 |
+
# Configure TorchDynamo to suppress errors and fall back to eager mode
|
| 338 |
+
import torch._dynamo
|
| 339 |
+
torch._dynamo.config.suppress_errors = args.dynamo_suppress_errors
|
| 340 |
+
torch._dynamo.config.verbose = False
|
| 341 |
+
# Set cache size limit to prevent memory issues during long training
|
| 342 |
+
torch._dynamo.config.cache_size_limit = args.dynamo_cache_limit
|
| 343 |
+
|
| 344 |
+
model = torch.compile(model) # requires PyTorch 2.0
|
| 345 |
+
|
| 346 |
+
# resume optimizer
|
| 347 |
+
if args.restore_optimizer and args.continue_from is not None:
|
| 348 |
+
logger.info("Restoring optimizer and scheduler from the checkpoint")
|
| 349 |
+
_optimizer_dir = args.continue_from
|
| 350 |
+
optimizer_checkpoint = torch.load(os.path.join(_optimizer_dir, "optimizer.pt"), map_location="cpu")
|
| 351 |
+
optimizer.load_state_dict(optimizer_checkpoint["optimizer"])
|
| 352 |
+
scheduler.load_state_dict(optimizer_checkpoint["scheduler"])
|
| 353 |
+
update_step = optimizer_checkpoint["update_step"]
|
| 354 |
+
beginning_step = update_step
|
| 355 |
+
global_step = optimizer_checkpoint["global_step"]
|
| 356 |
+
logger.info(f"Optimizer and scheduler restored from {_optimizer_dir}")
|
| 357 |
+
|
| 358 |
+
# ##############################
|
| 359 |
+
# TRAINING LOOP
|
| 360 |
+
# we use iterable dataset, so we may never go through all the data
|
| 361 |
+
# ##############################
|
| 362 |
+
# global steps and others are defined above
|
| 363 |
+
pad_idx = tokenizer.pad_token_id
|
| 364 |
+
update_time = time.time()
|
| 365 |
+
local_step = 0 # when continue_from is used, local_step != global_step
|
| 366 |
+
total_svd_count = 0
|
| 367 |
+
|
| 368 |
+
dataloader_iter = iter(dataloader)
|
| 369 |
+
|
| 370 |
+
# Issue-4 fix: accumulate loss across micro-batches so logged loss is the true
|
| 371 |
+
# gradient-accumulation average, not just the last micro-batch.
|
| 372 |
+
accumulated_loss = 0.0
|
| 373 |
+
|
| 374 |
+
# Skip data if resuming from checkpoint
|
| 375 |
+
if update_step != 0:
|
| 376 |
+
skip_batches = args.gradient_accumulation * update_step
|
| 377 |
+
logger.info(f"Skipping {skip_batches} batches to resume from update step {update_step}")
|
| 378 |
+
skipped = 0
|
| 379 |
+
for _ in range(skip_batches):
|
| 380 |
+
# Issue-6 fix: handle StopIteration during skip so all ranks stay aligned
|
| 381 |
+
try:
|
| 382 |
+
next(dataloader_iter)
|
| 383 |
+
except StopIteration:
|
| 384 |
+
logger.warning(
|
| 385 |
+
f"Dataset exhausted during skip at batch {skipped}/{skip_batches}; "
|
| 386 |
+
f"restarting iterator to keep ranks aligned."
|
| 387 |
+
)
|
| 388 |
+
dataloader_iter = iter(dataloader)
|
| 389 |
+
next(dataloader_iter)
|
| 390 |
+
skipped += 1
|
| 391 |
+
logger.info(f"Skipped {skipped} batches successfully")
|
| 392 |
+
|
| 393 |
+
while update_step <= args.num_training_steps:
|
| 394 |
+
try:
|
| 395 |
+
batch = next(dataloader_iter)
|
| 396 |
+
except StopIteration:
|
| 397 |
+
logger.info(f"Dataset completed one epoch. Starting new epoch with reshuffled data.")
|
| 398 |
+
dataloader_iter = iter(dataloader)
|
| 399 |
+
batch = next(dataloader_iter)
|
| 400 |
+
|
| 401 |
+
global_step += 1
|
| 402 |
+
local_step += 1
|
| 403 |
+
|
| 404 |
+
if update_step >= args.num_training_steps:
|
| 405 |
+
logger.info(f"Reached max number of update steps ({args.num_training_steps}). Stopping training.")
|
| 406 |
+
logger.info(f"Rank {global_rank} stopping training.")
|
| 407 |
+
break
|
| 408 |
+
|
| 409 |
+
# forward & backward
|
| 410 |
+
batch = {k: v.to(device) for k, v in batch.items()}
|
| 411 |
+
labels = batch["input_ids"].clone()
|
| 412 |
+
labels[labels == pad_idx] = -100
|
| 413 |
+
tokens_seen += (batch["input_ids"] != pad_idx).sum().item() * world_size
|
| 414 |
+
|
| 415 |
+
loss = model(**batch, labels=labels).loss
|
| 416 |
+
|
| 417 |
+
scaled_loss = loss / args.gradient_accumulation
|
| 418 |
+
scaled_loss.backward()
|
| 419 |
+
accumulated_loss += loss.item() # Issue-4: accumulate before the continue
|
| 420 |
+
|
| 421 |
+
if global_step % args.gradient_accumulation != 0:
|
| 422 |
+
continue
|
| 423 |
+
|
| 424 |
+
# The below code is only executed during the update step
|
| 425 |
+
# Issue-4: compute average loss over all micro-batches in this accumulation window
|
| 426 |
+
avg_loss = accumulated_loss / args.gradient_accumulation
|
| 427 |
+
accumulated_loss = 0.0 # reset for next accumulation window
|
| 428 |
+
# add grad clipping: TODO: add gradient clipping of int8 weight
|
| 429 |
+
if args.grad_clipping != 0.0:
|
| 430 |
+
torch.nn.utils.clip_grad_norm_(trainable_params, args.grad_clipping)
|
| 431 |
+
# Periodic memory cleanup to prevent symbolic tensor issues during long training
|
| 432 |
+
if global_step % args.memory_cleanup_frequency == 0:
|
| 433 |
+
torch.cuda.empty_cache()
|
| 434 |
+
# Clear TorchDynamo cache to prevent memory accumulation
|
| 435 |
+
if args.compile:
|
| 436 |
+
import torch._dynamo
|
| 437 |
+
torch._dynamo.reset()
|
| 438 |
+
|
| 439 |
+
if global_rank == 0:
|
| 440 |
+
pbar.update(1)
|
| 441 |
+
if not layer_wise_flag: # layer-wise updation is done during backward; requires gradient_accumulation equals 1
|
| 442 |
+
optimizer.step()
|
| 443 |
+
scheduler.step()
|
| 444 |
+
optimizer.zero_grad()
|
| 445 |
+
|
| 446 |
+
update_step += 1
|
| 447 |
+
update_time = time.time() - update_time
|
| 448 |
+
|
| 449 |
+
# save checkpoint by save_every
|
| 450 |
+
if local_step > args.gradient_accumulation and update_step % args.save_every == 0 and global_rank == 0:
|
| 451 |
+
current_model_directory = f"{args.save_dir}/model_{update_step}"
|
| 452 |
+
logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}")
|
| 453 |
+
os.makedirs(args.save_dir, exist_ok=True)
|
| 454 |
+
# Bug-1 fix: unwrap DDP/compiled model for saving; works in both single-GPU and multi-GPU modes
|
| 455 |
+
unwrapped_model = model.module if hasattr(model, 'module') else model
|
| 456 |
+
unwrapped_model.save_pretrained(current_model_directory, max_shard_size="500GB", from_pt=True)
|
| 457 |
+
saving_model_weight(unwrapped_model, f"{current_model_directory}/pytorch_model.bin", args)
|
| 458 |
+
|
| 459 |
+
optimizer_checkpoint = {
|
| 460 |
+
"optimizer": optimizer.state_dict(),
|
| 461 |
+
"scheduler": scheduler.state_dict(),
|
| 462 |
+
"update_step": update_step,
|
| 463 |
+
"global_step": global_step,
|
| 464 |
+
"config": run_config,
|
| 465 |
+
"wandb": wandb.run.dir if not args.unset_wandb else None,
|
| 466 |
+
"dtype": args.dtype,
|
| 467 |
+
}
|
| 468 |
+
torch.save(optimizer_checkpoint, f"{current_model_directory}/optimizer.pt")
|
| 469 |
+
|
| 470 |
+
training_state_checkpoint = {
|
| 471 |
+
"global_step": global_step,
|
| 472 |
+
"update_step": update_step,
|
| 473 |
+
"tokens_seen": tokens_seen,
|
| 474 |
+
"tokens_seen_before": tokens_seen_before,
|
| 475 |
+
"update_time": update_time,
|
| 476 |
+
}
|
| 477 |
+
with open(f"{current_model_directory}/training_state.json", "w") as f:
|
| 478 |
+
json.dump(training_state_checkpoint, f, indent=4)
|
| 479 |
+
|
| 480 |
+
# save wandb related info
|
| 481 |
+
if not args.unset_wandb:
|
| 482 |
+
wandb_info = {
|
| 483 |
+
"wandb_id": wandb.run.id,
|
| 484 |
+
}
|
| 485 |
+
with open(f"{args.save_dir}/wandb.json", "w") as f:
|
| 486 |
+
json.dump(wandb_info, f, indent=4)
|
| 487 |
+
|
| 488 |
+
# evaluation
|
| 489 |
+
if update_step % args.eval_every == 0:
|
| 490 |
+
logger.info(f"Performing evaluation at step {update_step}")
|
| 491 |
+
total_loss, evaluated_on_tokens, perplexity = evaluate_model(
|
| 492 |
+
model, tokenizer, pad_idx, global_rank, world_size, device, args
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
if global_rank == 0:
|
| 496 |
+
if not args.unset_wandb:
|
| 497 |
+
wandb.log(
|
| 498 |
+
{
|
| 499 |
+
"eval_loss": total_loss,
|
| 500 |
+
"eval_perplexity": perplexity,
|
| 501 |
+
"eval_tokens": evaluated_on_tokens,
|
| 502 |
+
},
|
| 503 |
+
step=update_step,
|
| 504 |
+
)
|
| 505 |
+
logger.info(f"Eval loss at step {update_step}: {total_loss}, Eval perplexity: {perplexity}")
|
| 506 |
+
|
| 507 |
+
if not layer_wise_flag:
|
| 508 |
+
lr = optimizer.param_groups[0]["lr"]
|
| 509 |
+
else:
|
| 510 |
+
lr = list(optimizer_dict.values())[0].param_groups[0]["lr"]
|
| 511 |
+
tokens_in_update = tokens_seen - tokens_seen_before
|
| 512 |
+
tokens_seen_before = tokens_seen
|
| 513 |
+
batches_in_update = args.gradient_accumulation * world_size
|
| 514 |
+
if not layer_wise_flag:
|
| 515 |
+
total_svd_count = getting_svd_cnt(optimizer)
|
| 516 |
+
else:
|
| 517 |
+
total_svd_count = 0
|
| 518 |
+
|
| 519 |
+
if global_rank == 0:
|
| 520 |
+
if not args.unset_wandb:
|
| 521 |
+
wandb.log(
|
| 522 |
+
{
|
| 523 |
+
"loss": avg_loss,
|
| 524 |
+
"lr": lr,
|
| 525 |
+
"update_step": update_step,
|
| 526 |
+
"tokens_seen": tokens_seen,
|
| 527 |
+
"total_svd_count": total_svd_count,
|
| 528 |
+
"throughput_tokens": tokens_in_update / update_time,
|
| 529 |
+
"throughput_examples": args.total_batch_size / update_time,
|
| 530 |
+
"throughput_batches": batches_in_update / update_time,
|
| 531 |
+
},
|
| 532 |
+
step=update_step,
|
| 533 |
+
)
|
| 534 |
+
update_time = time.time()
|
| 535 |
+
|
| 536 |
+
# ##############################
|
| 537 |
+
# END of training loop
|
| 538 |
+
# ##############################
|
| 539 |
+
logger.info("Training finished")
|
| 540 |
+
if global_rank == 0:
|
| 541 |
+
pbar.close()
|
| 542 |
+
|
| 543 |
+
current_model_directory = f"{args.save_dir}/model_{update_step}"
|
| 544 |
+
if global_rank == 0 and not os.path.exists(current_model_directory):
|
| 545 |
+
logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}")
|
| 546 |
+
os.makedirs(args.save_dir, exist_ok=True)
|
| 547 |
+
# Bug-1 fix: unwrap DDP/compiled model for saving; works in both single-GPU and multi-GPU modes
|
| 548 |
+
unwrapped_model = model.module if hasattr(model, 'module') else model
|
| 549 |
+
unwrapped_model.save_pretrained(current_model_directory, max_shard_size="500GB", from_pt=True)
|
| 550 |
+
saving_model_weight(unwrapped_model, f"{current_model_directory}/pytorch_model.bin", args)
|
| 551 |
+
|
| 552 |
+
optimizer_checkpoint = {
|
| 553 |
+
"optimizer": optimizer.state_dict(),
|
| 554 |
+
"scheduler": scheduler.state_dict(),
|
| 555 |
+
"update_step": update_step,
|
| 556 |
+
"global_step": global_step,
|
| 557 |
+
"config": run_config,
|
| 558 |
+
"wandb": wandb.run.dir if not args.unset_wandb else None,
|
| 559 |
+
"dtype": args.dtype,
|
| 560 |
+
}
|
| 561 |
+
torch.save(optimizer_checkpoint, f"{current_model_directory}/optimizer.pt")
|
| 562 |
+
|
| 563 |
+
training_state_checkpoint = {
|
| 564 |
+
"global_step": global_step,
|
| 565 |
+
"update_step": update_step,
|
| 566 |
+
"tokens_seen": tokens_seen,
|
| 567 |
+
"tokens_seen_before": tokens_seen_before,
|
| 568 |
+
"update_time": update_time,
|
| 569 |
+
}
|
| 570 |
+
with open(f"{current_model_directory}/training_state.json", "w") as f:
|
| 571 |
+
json.dump(training_state_checkpoint, f, indent=4)
|
| 572 |
+
|
| 573 |
+
# Final evaluation
|
| 574 |
+
logger.info("Running final evaluation")
|
| 575 |
+
model.eval()
|
| 576 |
+
del loss, optimizer, scheduler
|
| 577 |
+
import gc
|
| 578 |
+
|
| 579 |
+
gc.collect()
|
| 580 |
+
torch.cuda.empty_cache()
|
| 581 |
+
|
| 582 |
+
total_loss, evaluated_on_tokens, perplexity = evaluate_model(model, tokenizer, pad_idx, global_rank, world_size, device, args)
|
| 583 |
+
|
| 584 |
+
if global_rank == 0:
|
| 585 |
+
if not args.unset_wandb:
|
| 586 |
+
wandb.log(
|
| 587 |
+
{
|
| 588 |
+
"final_eval_loss": total_loss,
|
| 589 |
+
"final_eval_perplexity": perplexity,
|
| 590 |
+
"final_eval_tokens": evaluated_on_tokens,
|
| 591 |
+
},
|
| 592 |
+
step=update_step,
|
| 593 |
+
)
|
| 594 |
+
logger.info(f"Final eval loss: {total_loss}, Final eval perplexity: {perplexity}")
|
| 595 |
+
|
| 596 |
+
logger.info("Script finished successfully")
|
| 597 |
+
print(f"Rank {global_rank} finished successfully")
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
if __name__ == "__main__":
|
| 601 |
+
print("Starting script")
|
| 602 |
+
args = parse_args(None)
|
| 603 |
+
main(args)
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/components/__init__.py
ADDED
|
File without changes
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/components/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (197 Bytes). View file
|
|
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/components/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (235 Bytes). View file
|
|
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/components/__pycache__/checkpoint.cpython-310.pyc
ADDED
|
Binary file (1.92 kB). View file
|
|
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/components/__pycache__/checkpoint.cpython-311.pyc
ADDED
|
Binary file (3.7 kB). View file
|
|
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/components/checkpoint.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from datetime import timedelta
|
| 9 |
+
from io import BytesIO
|
| 10 |
+
from typing import Any, Dict, List
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch.distributed.checkpoint.stateful import Stateful
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class TrainState(Stateful):
|
| 18 |
+
step: int = 0
|
| 19 |
+
skipped_step: int = 0
|
| 20 |
+
token: int = 0
|
| 21 |
+
elapsed: timedelta = timedelta(0)
|
| 22 |
+
global_avg_losses: List[float] = field(default_factory=list)
|
| 23 |
+
global_max_losses: List[float] = field(default_factory=list)
|
| 24 |
+
log_steps: List[int] = field(default_factory=list)
|
| 25 |
+
|
| 26 |
+
def state_dict(self) -> Dict[str, Any]:
|
| 27 |
+
# Only checkpoint global_avg_losses and global_max_losses per log frequency
|
| 28 |
+
# to avoid sync overhead in every iteration.
|
| 29 |
+
global_avg_losses_bytes = BytesIO()
|
| 30 |
+
torch.save(self.global_avg_losses, global_avg_losses_bytes)
|
| 31 |
+
global_max_losses_bytes = BytesIO()
|
| 32 |
+
torch.save(self.global_max_losses, global_max_losses_bytes)
|
| 33 |
+
log_steps_bytes = BytesIO()
|
| 34 |
+
torch.save(self.log_steps, log_steps_bytes)
|
| 35 |
+
return {
|
| 36 |
+
"step": torch.tensor(self.step, dtype=torch.int32),
|
| 37 |
+
"skipped_step": torch.tensor(self.skipped_step, dtype=torch.int32),
|
| 38 |
+
"token": torch.tensor(self.token, dtype=torch.int64),
|
| 39 |
+
"elapsed": self.elapsed,
|
| 40 |
+
"global_avg_losses": global_avg_losses_bytes,
|
| 41 |
+
"global_max_losses": global_max_losses_bytes,
|
| 42 |
+
"log_steps": log_steps_bytes,
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
def load_state_dict(self, state_dict) -> None:
|
| 46 |
+
self.step = state_dict["step"].item()
|
| 47 |
+
self.skipped_step = state_dict.get("skipped_step", 0).item()
|
| 48 |
+
self.token = state_dict["token"].item()
|
| 49 |
+
self.elapsed = state_dict["elapsed"]
|
| 50 |
+
state_dict["global_avg_losses"].seek(0)
|
| 51 |
+
self.global_avg_losses = torch.load(
|
| 52 |
+
state_dict["global_avg_losses"], weights_only=False
|
| 53 |
+
)
|
| 54 |
+
state_dict["global_max_losses"].seek(0)
|
| 55 |
+
self.global_max_losses = torch.load(
|
| 56 |
+
state_dict["global_max_losses"], weights_only=False
|
| 57 |
+
)
|
| 58 |
+
state_dict["log_steps"].seek(0)
|
| 59 |
+
self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
|
transformer_pp_1b_c4/transformer_pp_1b_c4_valc4_soap_pdim2048_pfreq10_lr3e_3_b1_0_9_b2_0_95_eps_1e_15_20260508_191338/exp_data/flame/config_manager.py
ADDED
|
@@ -0,0 +1,981 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import sys
|
| 9 |
+
from collections import defaultdict
|
| 10 |
+
from typing import Tuple
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
import tomllib
|
| 16 |
+
except ModuleNotFoundError:
|
| 17 |
+
import tomli as tomllib
|
| 18 |
+
|
| 19 |
+
from torchtitan.tools.logging import logger
|
| 20 |
+
|
| 21 |
+
TORCH_DTYPE_MAP = {
|
| 22 |
+
"float16": torch.float16,
|
| 23 |
+
"float32": torch.float32,
|
| 24 |
+
"bfloat16": torch.bfloat16,
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def string_list(raw_arg):
|
| 29 |
+
"""Comma-separated string list argument."""
|
| 30 |
+
return [s.strip() for s in raw_arg.split(",") if s.strip()]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def check_string_list_argument(args_dict: dict[str, any], fullargname: str):
|
| 34 |
+
section, name = fullargname.split(".")
|
| 35 |
+
# Split string list which are still raw strings.
|
| 36 |
+
if (
|
| 37 |
+
section in args_dict
|
| 38 |
+
and name in args_dict[section]
|
| 39 |
+
and isinstance(args_dict[section][name], str)
|
| 40 |
+
):
|
| 41 |
+
sec = args_dict[section]
|
| 42 |
+
sec[name] = string_list(sec[name])
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class JobConfig:
|
| 46 |
+
"""
|
| 47 |
+
A helper class to manage the train configuration.
|
| 48 |
+
Semantics:
|
| 49 |
+
- Default config is loaded from a toml file. If no toml file is provided,
|
| 50 |
+
then the default config is loaded from argparse defaults.
|
| 51 |
+
- if toml file has missing keys, they are filled with argparse defaults.
|
| 52 |
+
- if additional explicit cmd args are provided in addition to the toml
|
| 53 |
+
file, they will override the toml config and the argparse defaults
|
| 54 |
+
|
| 55 |
+
precedence order: cmdline > toml > argparse default
|
| 56 |
+
|
| 57 |
+
Arg parsing semantics:
|
| 58 |
+
|
| 59 |
+
Each argument starts with <prefix>_ which is the section name in the toml file
|
| 60 |
+
followed by name of the option in the toml file. For ex,
|
| 61 |
+
model.name translates to:
|
| 62 |
+
[model]
|
| 63 |
+
name
|
| 64 |
+
in the toml file
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
def __init__(self):
|
| 68 |
+
self.args_dict = None
|
| 69 |
+
# main parser
|
| 70 |
+
self.parser = argparse.ArgumentParser(description="torchtitan arg parser.")
|
| 71 |
+
|
| 72 |
+
self.parser.add_argument(
|
| 73 |
+
"--job.config_file",
|
| 74 |
+
type=str,
|
| 75 |
+
default=None,
|
| 76 |
+
help="Job config file",
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
# job level configs
|
| 80 |
+
self.parser.add_argument(
|
| 81 |
+
"--job.dump_folder",
|
| 82 |
+
type=str,
|
| 83 |
+
default="./torchtitan/outputs",
|
| 84 |
+
help="Folder to dump job outputs",
|
| 85 |
+
)
|
| 86 |
+
self.parser.add_argument(
|
| 87 |
+
"--job.description",
|
| 88 |
+
type=str,
|
| 89 |
+
default="default job",
|
| 90 |
+
help="Description of the job",
|
| 91 |
+
)
|
| 92 |
+
self.parser.add_argument(
|
| 93 |
+
"--job.use_for_integration_test",
|
| 94 |
+
action="store_true",
|
| 95 |
+
help="Add this config to the integration test suite",
|
| 96 |
+
)
|
| 97 |
+
self.parser.add_argument(
|
| 98 |
+
"--job.print_args",
|
| 99 |
+
action="store_true",
|
| 100 |
+
help="Print the args to terminal",
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# model configs
|
| 104 |
+
self.parser.add_argument(
|
| 105 |
+
"--model.name",
|
| 106 |
+
type=str,
|
| 107 |
+
default="fla",
|
| 108 |
+
help="Which model to train",
|
| 109 |
+
)
|
| 110 |
+
self.parser.add_argument(
|
| 111 |
+
"--model.config",
|
| 112 |
+
type=str,
|
| 113 |
+
default="fla-hub/transformer-1.3B-100B",
|
| 114 |
+
help="Path to the model config",
|
| 115 |
+
)
|
| 116 |
+
self.parser.add_argument(
|
| 117 |
+
"--model.tokenizer_path",
|
| 118 |
+
type=str,
|
| 119 |
+
default="fla-hub/transformer-1.3B-100B",
|
| 120 |
+
help="Tokenizer path",
|
| 121 |
+
)
|
| 122 |
+
self.parser.add_argument(
|
| 123 |
+
"--model.converters",
|
| 124 |
+
type=string_list,
|
| 125 |
+
nargs="+",
|
| 126 |
+
default=[],
|
| 127 |
+
help="""
|
| 128 |
+
Comma separated list of converters to apply to the model.
|
| 129 |
+
For instance, the `float8` converter swaps `torch.nn.Linear`
|
| 130 |
+
with `Float8Linear`. This feature requires you to install 'torchao'
|
| 131 |
+
which can be found here: https://github.com/pytorch/ao
|
| 132 |
+
""",
|
| 133 |
+
)
|
| 134 |
+
self.parser.add_argument(
|
| 135 |
+
"--model.print_after_conversion",
|
| 136 |
+
action="store_true",
|
| 137 |
+
help="""
|
| 138 |
+
If true, model definition will be printed to stdout after all model
|
| 139 |
+
converters have been applied.
|
| 140 |
+
""",
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# profiling configs
|
| 144 |
+
self.parser.add_argument(
|
| 145 |
+
"--profiling.enable_profiling",
|
| 146 |
+
action="store_true",
|
| 147 |
+
help="Whether to enable pytorch profiler",
|
| 148 |
+
)
|
| 149 |
+
self.parser.add_argument(
|
| 150 |
+
"--profiling.save_traces_folder",
|
| 151 |
+
type=str,
|
| 152 |
+
default="profile_traces",
|
| 153 |
+
help="Trace files location",
|
| 154 |
+
)
|
| 155 |
+
self.parser.add_argument(
|
| 156 |
+
"--profiling.profile_freq",
|
| 157 |
+
type=int,
|
| 158 |
+
default=10,
|
| 159 |
+
help="How often to collect profiler traces, in iterations",
|
| 160 |
+
)
|
| 161 |
+
self.parser.add_argument(
|
| 162 |
+
"--profiling.enable_memory_snapshot",
|
| 163 |
+
action="store_true",
|
| 164 |
+
help="Whether to dump memory snapshot",
|
| 165 |
+
)
|
| 166 |
+
self.parser.add_argument(
|
| 167 |
+
"--profiling.save_memory_snapshot_folder",
|
| 168 |
+
type=str,
|
| 169 |
+
default="memory_snapshot",
|
| 170 |
+
help="Memeory snapshot files location",
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# optimizer configs
|
| 174 |
+
self.parser.add_argument(
|
| 175 |
+
"--optimizer.name", type=str, default="AdamW", help="Optimizer to use"
|
| 176 |
+
)
|
| 177 |
+
self.parser.add_argument(
|
| 178 |
+
"--optimizer.eps",
|
| 179 |
+
type=float,
|
| 180 |
+
default=1e-8,
|
| 181 |
+
help="Epsilon value for the optimizer.",
|
| 182 |
+
)
|
| 183 |
+
self.parser.add_argument(
|
| 184 |
+
"--optimizer.lr", type=float, default=8e-4, help="Learning rate to use"
|
| 185 |
+
)
|
| 186 |
+
self.parser.add_argument(
|
| 187 |
+
"--optimizer.beta1", type=float, default=0.9,
|
| 188 |
+
help="Exponential moving average hyperparameters to use"
|
| 189 |
+
)
|
| 190 |
+
self.parser.add_argument(
|
| 191 |
+
"--optimizer.beta2", type=float, default=0.95,
|
| 192 |
+
help="Exponential moving average hyperparameters to use"
|
| 193 |
+
)
|
| 194 |
+
self.parser.add_argument(
|
| 195 |
+
"--optimizer.weight_decay", type=float, default=0.1,
|
| 196 |
+
help="Weight decay to use"
|
| 197 |
+
)
|
| 198 |
+
self.parser.add_argument(
|
| 199 |
+
"--optimizer.implementation",
|
| 200 |
+
type=str,
|
| 201 |
+
default="fused",
|
| 202 |
+
choices=["for-loop", "foreach", "fused"],
|
| 203 |
+
help="""
|
| 204 |
+
Specify which optimizer implementation to use:
|
| 205 |
+
- 'fused': Use fused implementation (CUDA only) for best performance.
|
| 206 |
+
- 'foreach': Use some horizontal fusion of tensors for better performance.
|
| 207 |
+
- 'for-loop': Use the default implementation for the optimizer (slowest).
|
| 208 |
+
- more info: https://pytorch.org/docs/stable/optim.html
|
| 209 |
+
""",
|
| 210 |
+
)
|
| 211 |
+
self.parser.add_argument(
|
| 212 |
+
"--optimizer.early_step_in_backward",
|
| 213 |
+
action="store_true",
|
| 214 |
+
help="""
|
| 215 |
+
Whether to apply optimizer in the backward. Caution, optimizer_in_backward
|
| 216 |
+
is not compatible with gradients clipping, users should not call
|
| 217 |
+
register_post_accumulate_grad_hook after the optimizer is built.""",
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
# lr scheduler configs
|
| 221 |
+
self.parser.add_argument(
|
| 222 |
+
"--lr_scheduler.warmup_steps",
|
| 223 |
+
type=int,
|
| 224 |
+
default=200,
|
| 225 |
+
help="Steps for lr scheduler warmup, normally 1/5 of --training.steps",
|
| 226 |
+
)
|
| 227 |
+
self.parser.add_argument(
|
| 228 |
+
"--lr_scheduler.decay_ratio",
|
| 229 |
+
type=float,
|
| 230 |
+
default=None,
|
| 231 |
+
help="""
|
| 232 |
+
Controls the proportion of the training steps allocated to the learning rate decay phase.
|
| 233 |
+
|
| 234 |
+
If `None`, the learning rate will begin decaying immediately after the warmup period.
|
| 235 |
+
Otherwise, the learning rate will remain stable after the warmup period and
|
| 236 |
+
only start decaying during the last `decay_ratio` portion of the total training steps.
|
| 237 |
+
|
| 238 |
+
This is known as the Warmup-Stable-Decay (WSD) schedule, as described in https://arxiv.org/abs/2404.06395.
|
| 239 |
+
""",
|
| 240 |
+
)
|
| 241 |
+
self.parser.add_argument(
|
| 242 |
+
"--lr_scheduler.decay_type",
|
| 243 |
+
type=str,
|
| 244 |
+
default="linear",
|
| 245 |
+
choices=["linear", "sqrt", "cosine"],
|
| 246 |
+
help="""
|
| 247 |
+
Learning rate decay type to use during training:
|
| 248 |
+
- 'linear': linearly decays learning rate from initial to final value
|
| 249 |
+
- 'sqrt': decays learning rate following a 1 minus square root curve
|
| 250 |
+
- 'cosine': smoothly decays learning rate following a cosine curve
|
| 251 |
+
""",
|
| 252 |
+
)
|
| 253 |
+
self.parser.add_argument(
|
| 254 |
+
"--lr_scheduler.lr_min",
|
| 255 |
+
type=float,
|
| 256 |
+
default=0.0,
|
| 257 |
+
help="""
|
| 258 |
+
Min lr ratio for lr scheduler.
|
| 259 |
+
|
| 260 |
+
If provided, the range of decay factor is scaled from 1 to `lr_min`
|
| 261 |
+
to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.lr_min`.
|
| 262 |
+
""",
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
# training configs
|
| 266 |
+
self.parser.add_argument(
|
| 267 |
+
"--training.batch_size", type=int, default=8, help="Batch size"
|
| 268 |
+
)
|
| 269 |
+
self.parser.add_argument(
|
| 270 |
+
"--training.seq_len", type=int, default=2048, help="Sequence length"
|
| 271 |
+
)
|
| 272 |
+
self.parser.add_argument(
|
| 273 |
+
"--training.context_len",
|
| 274 |
+
type=int,
|
| 275 |
+
default=2048,
|
| 276 |
+
help="Max length allowed for each sequence",
|
| 277 |
+
)
|
| 278 |
+
self.parser.add_argument(
|
| 279 |
+
"--training.varlen",
|
| 280 |
+
action="store_true",
|
| 281 |
+
help="Whether to take sequences of variable length as input",
|
| 282 |
+
)
|
| 283 |
+
self.parser.add_argument(
|
| 284 |
+
"--training.gradient_accumulation_steps",
|
| 285 |
+
type=int,
|
| 286 |
+
default=1,
|
| 287 |
+
help="Number of steps to accumulate gradients before updating parameters",
|
| 288 |
+
)
|
| 289 |
+
self.parser.add_argument(
|
| 290 |
+
"--training.steps",
|
| 291 |
+
type=int,
|
| 292 |
+
default=10000,
|
| 293 |
+
help="How many train steps to run",
|
| 294 |
+
)
|
| 295 |
+
self.parser.add_argument(
|
| 296 |
+
"--training.max_norm",
|
| 297 |
+
type=float,
|
| 298 |
+
default=1.0,
|
| 299 |
+
help="Max norm for gradient clipping",
|
| 300 |
+
)
|
| 301 |
+
self.parser.add_argument(
|
| 302 |
+
"--training.skip_nan_inf",
|
| 303 |
+
action="store_true",
|
| 304 |
+
help="Skip batch updates when NaN or INF gradients are encountered during training",
|
| 305 |
+
)
|
| 306 |
+
self.parser.add_argument(
|
| 307 |
+
"--training.dataset",
|
| 308 |
+
default="HuggingFaceFW/fineweb-edu",
|
| 309 |
+
help="Dataset to use, with comma separated values",
|
| 310 |
+
)
|
| 311 |
+
self.parser.add_argument(
|
| 312 |
+
"--training.dataset_name",
|
| 313 |
+
default=None,
|
| 314 |
+
help="The name of the dataset config, with comma separated values if provided",
|
| 315 |
+
)
|
| 316 |
+
self.parser.add_argument(
|
| 317 |
+
"--training.dataset_split",
|
| 318 |
+
default=None,
|
| 319 |
+
help="Dataset split to use, with comma separated values if provided",
|
| 320 |
+
)
|
| 321 |
+
self.parser.add_argument(
|
| 322 |
+
"--training.data_dir",
|
| 323 |
+
default=None,
|
| 324 |
+
help="Data dirs to use, with comma separated values if provided",
|
| 325 |
+
)
|
| 326 |
+
self.parser.add_argument(
|
| 327 |
+
"--training.data_files",
|
| 328 |
+
default=None,
|
| 329 |
+
help="Data files to use, with comma separated values if provided",
|
| 330 |
+
)
|
| 331 |
+
self.parser.add_argument(
|
| 332 |
+
"--training.data_probs",
|
| 333 |
+
default=None,
|
| 334 |
+
help="Data sampling probabilities, with comma separated values if provided",
|
| 335 |
+
)
|
| 336 |
+
self.parser.add_argument(
|
| 337 |
+
"--training.streaming",
|
| 338 |
+
action="store_true",
|
| 339 |
+
help="Whether to load dataset in streaming mode, used for huge dataset",
|
| 340 |
+
)
|
| 341 |
+
self.parser.add_argument(
|
| 342 |
+
"--training.num_workers",
|
| 343 |
+
type=int,
|
| 344 |
+
default=32,
|
| 345 |
+
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
|
| 346 |
+
)
|
| 347 |
+
self.parser.add_argument(
|
| 348 |
+
"--training.prefetch_factor",
|
| 349 |
+
type=int,
|
| 350 |
+
default=2,
|
| 351 |
+
help="Number of batches loaded in advance by each worker."
|
| 352 |
+
"2 means there will be a total of 2 * num_workers batches prefetched across all workers.",
|
| 353 |
+
)
|
| 354 |
+
self.parser.add_argument(
|
| 355 |
+
"--training.data_parallel_replicate_degree",
|
| 356 |
+
type=int,
|
| 357 |
+
default=1,
|
| 358 |
+
help="""
|
| 359 |
+
The `data_parallel_replicate_degree` argument specifies the degree of
|
| 360 |
+
data parallelism for weight replication. When this value is greater
|
| 361 |
+
than 1, weights will be replicated across `data_parallel_replicate_degree`
|
| 362 |
+
ranks. If `data_parallel_shard_degree` is also greater than 1, the parallelism
|
| 363 |
+
method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
|
| 364 |
+
parallelism method used is DDP (Distributed Data Parallelism).
|
| 365 |
+
1 means disabled.""",
|
| 366 |
+
)
|
| 367 |
+
self.parser.add_argument(
|
| 368 |
+
"--training.data_parallel_shard_degree",
|
| 369 |
+
type=int,
|
| 370 |
+
default=-1,
|
| 371 |
+
help="""
|
| 372 |
+
The `data_parallel_shard_degree` argument specifies the degree of data
|
| 373 |
+
parallelism for weight sharding. When this value is greater than 1, weights
|
| 374 |
+
will be sharded across `data_parallel_shard_degree` ranks. If
|
| 375 |
+
`data_parallel_replicate_degree` is also greater than 1, the parallelism
|
| 376 |
+
method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
|
| 377 |
+
parallelism method used is FSDP (Fully Sharded Data Parallelism).
|
| 378 |
+
|
| 379 |
+
-1 means leftover ranks will be used (After DP_REPLICATE/SP/PP). Note that
|
| 380 |
+
only `data_parallel_shard_degree` can be negative. 1 means disabled.""",
|
| 381 |
+
)
|
| 382 |
+
self.parser.add_argument(
|
| 383 |
+
"--training.enable_cpu_offload",
|
| 384 |
+
action="store_true",
|
| 385 |
+
help="""
|
| 386 |
+
Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP""",
|
| 387 |
+
)
|
| 388 |
+
self.parser.add_argument(
|
| 389 |
+
"--training.tensor_parallel_degree",
|
| 390 |
+
type=int,
|
| 391 |
+
default=1,
|
| 392 |
+
help="Tensor Parallelism degree. 1 means disabled.",
|
| 393 |
+
)
|
| 394 |
+
self.parser.add_argument(
|
| 395 |
+
"--training.disable_loss_parallel",
|
| 396 |
+
action="store_true",
|
| 397 |
+
help="Whether to apply loss parallel when sequence parallel is enabled",
|
| 398 |
+
)
|
| 399 |
+
self.parser.add_argument(
|
| 400 |
+
"--training.fsdp_reshard_after_forward",
|
| 401 |
+
type=str,
|
| 402 |
+
default="default",
|
| 403 |
+
choices=["default", "always", "never"],
|
| 404 |
+
help="""
|
| 405 |
+
`reshard_after_forward` specifies the policy for applying `reshard_after_forward`
|
| 406 |
+
within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward,
|
| 407 |
+
trading off memory and communication. See torch's `fully_shard` API for more documentation
|
| 408 |
+
on `reshard_after_forward`.
|
| 409 |
+
The supported policies include "default", "always" and "never":
|
| 410 |
+
- "default" applies default resharding behavior, implementing "smart defaults" for known optimal
|
| 411 |
+
scenarios.
|
| 412 |
+
- "always" will enable `reshard_after_forward` for all forward passes.
|
| 413 |
+
- "never" will disable `reshard_after_forward` for all forward passes.
|
| 414 |
+
""",
|
| 415 |
+
)
|
| 416 |
+
self.parser.add_argument(
|
| 417 |
+
"--training.mixed_precision_param",
|
| 418 |
+
type=str,
|
| 419 |
+
default="bfloat16",
|
| 420 |
+
choices=["bfloat16", "float32"],
|
| 421 |
+
help="""
|
| 422 |
+
torch dtype to use for parameters when applying mixed precision via fully_shard or torch.autocast.
|
| 423 |
+
This feature takes effect via fully_shard when data_parallel_shard_degree > 1 or
|
| 424 |
+
context_parallel_degree > 1; it takes effect via torch.autocast when data_replicate_degree >= 1
|
| 425 |
+
and no other parallelism is enabled, i.e. under DDP or single-device training.
|
| 426 |
+
""",
|
| 427 |
+
)
|
| 428 |
+
self.parser.add_argument(
|
| 429 |
+
"--training.mixed_precision_reduce",
|
| 430 |
+
type=str,
|
| 431 |
+
default="float32",
|
| 432 |
+
choices=["float32"],
|
| 433 |
+
help="""
|
| 434 |
+
torch dtype to use for reductions when applying mixed precision via FSDP.
|
| 435 |
+
This feature only takes effect when data_parallel_shard_degree > 1
|
| 436 |
+
""",
|
| 437 |
+
)
|
| 438 |
+
self.parser.add_argument(
|
| 439 |
+
"--training.compile",
|
| 440 |
+
action="store_true",
|
| 441 |
+
help="Whether to compile the model",
|
| 442 |
+
)
|
| 443 |
+
self.parser.add_argument(
|
| 444 |
+
"--training.gc_freq",
|
| 445 |
+
type=int,
|
| 446 |
+
default=50,
|
| 447 |
+
help="Python garbage control scheduling interval, in steps",
|
| 448 |
+
)
|
| 449 |
+
self.parser.add_argument(
|
| 450 |
+
"--training.seed",
|
| 451 |
+
type=int,
|
| 452 |
+
default=42,
|
| 453 |
+
help="Choose the base RNG seed used for training",
|
| 454 |
+
)
|
| 455 |
+
self.parser.add_argument(
|
| 456 |
+
"--training.deterministic",
|
| 457 |
+
action="store_true",
|
| 458 |
+
help="Use deterministic algorithms wherever possible, may be slower",
|
| 459 |
+
)
|
| 460 |
+
# ------ jinxin ------ #
|
| 461 |
+
self.parser.add_argument(
|
| 462 |
+
"--training.val_times",
|
| 463 |
+
type=int,
|
| 464 |
+
default=0,
|
| 465 |
+
help="Number of times to evaluate val PPL during training. 0 means no intermediate eval. "
|
| 466 |
+
"e.g. 10 means evaluate every (total_steps // 10) steps.",
|
| 467 |
+
)
|
| 468 |
+
self.parser.add_argument(
|
| 469 |
+
"--training.val_data_dir",
|
| 470 |
+
type=str,
|
| 471 |
+
default=None,
|
| 472 |
+
help="Path to the validation data directory containing parquet files. "
|
| 473 |
+
"If None, defaults to 'data/wiki_val/' relative to cwd.",
|
| 474 |
+
)
|
| 475 |
+
# metrics configs
|
| 476 |
+
self.parser.add_argument(
|
| 477 |
+
"--metrics.log_freq",
|
| 478 |
+
type=int,
|
| 479 |
+
default=10,
|
| 480 |
+
help="How often to log metrics to TensorBoard, in iterations",
|
| 481 |
+
)
|
| 482 |
+
self.parser.add_argument(
|
| 483 |
+
"--metrics.enable_tensorboard",
|
| 484 |
+
action="store_true",
|
| 485 |
+
help="Whether to log metrics to TensorBoard",
|
| 486 |
+
)
|
| 487 |
+
self.parser.add_argument(
|
| 488 |
+
"--metrics.disable_color_printing",
|
| 489 |
+
action="store_true",
|
| 490 |
+
help="Whether to disable color printing in logs",
|
| 491 |
+
)
|
| 492 |
+
self.parser.add_argument(
|
| 493 |
+
"--metrics.save_tb_folder",
|
| 494 |
+
type=str,
|
| 495 |
+
default="tb",
|
| 496 |
+
help="Folder to dump TensorBoard states",
|
| 497 |
+
)
|
| 498 |
+
self.parser.add_argument(
|
| 499 |
+
"--metrics.save_for_all_ranks",
|
| 500 |
+
action="store_true",
|
| 501 |
+
default=False,
|
| 502 |
+
help="""
|
| 503 |
+
Whether to save TensorBoard/Wandb metrics only for rank 0 or for all ranks.
|
| 504 |
+
When this option is False and pipeline_parallel_degree is > 1, the metrics
|
| 505 |
+
component uses the 0th rank of the last stage pipeline group, which is the
|
| 506 |
+
only stage that computes loss metrics.
|
| 507 |
+
""",
|
| 508 |
+
)
|
| 509 |
+
self.parser.add_argument(
|
| 510 |
+
"--metrics.enable_wandb",
|
| 511 |
+
action="store_true",
|
| 512 |
+
help="Whether to log metrics to Weights & Biases",
|
| 513 |
+
)
|
| 514 |
+
self.parser.add_argument(
|
| 515 |
+
"--no-metrics.enable_wandb",
|
| 516 |
+
dest="metrics.enable_wandb",
|
| 517 |
+
action="store_false",
|
| 518 |
+
help="Disable Weights & Biases logging (e.g. to avoid disk quota on ~/.cache/wandb)",
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
self.parser.add_argument(
|
| 522 |
+
"--experimental.enable_async_tensor_parallel",
|
| 523 |
+
action="store_true",
|
| 524 |
+
help="Whether to apply async tensor parallel (currently only effective when compile is enabled)",
|
| 525 |
+
)
|
| 526 |
+
self.parser.add_argument(
|
| 527 |
+
"--experimental.pipeline_parallel_degree",
|
| 528 |
+
type=int,
|
| 529 |
+
default=1,
|
| 530 |
+
help="""
|
| 531 |
+
Pipeline Parallelism degree, or number of ranks. 1 means disabled.
|
| 532 |
+
If using looped schedules, this still specifies the number of physical ranks, not the number
|
| 533 |
+
of stages. Stages per rank are inferred from split points degree, and schedule.""",
|
| 534 |
+
)
|
| 535 |
+
self.parser.add_argument(
|
| 536 |
+
"--experimental.pipeline_parallel_split_points",
|
| 537 |
+
type=string_list,
|
| 538 |
+
nargs="+",
|
| 539 |
+
default=[],
|
| 540 |
+
help="""
|
| 541 |
+
Specify comma-separated names of modules to use as the beginning of a split point.
|
| 542 |
+
|
| 543 |
+
e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
|
| 544 |
+
the first containing all the layers up to layers.0,
|
| 545 |
+
the second containing layers.0 and up to layers.2,
|
| 546 |
+
the third containing layers.2 and all the remaining layers.
|
| 547 |
+
|
| 548 |
+
Note: fully-automated splitting may be enabled in the future,
|
| 549 |
+
but currently the split points must be specified manually.""",
|
| 550 |
+
)
|
| 551 |
+
self.parser.add_argument(
|
| 552 |
+
"--experimental.pipeline_parallel_schedule",
|
| 553 |
+
type=str,
|
| 554 |
+
default="1F1B",
|
| 555 |
+
help="""
|
| 556 |
+
Specify the Pipeline Parallel schedule to use. The supported schedules are:
|
| 557 |
+
https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/distributed/pipelining/schedules.py#L2161.
|
| 558 |
+
The schedule must be compatible with the split points and stages_per_rank.
|
| 559 |
+
|
| 560 |
+
Looped schedules (e.g. Interleaved1F1B) require specifying pipeline_parallel_degree = number of ranks,
|
| 561 |
+
and split_points = number of stages - 1
|
| 562 |
+
""",
|
| 563 |
+
)
|
| 564 |
+
self.parser.add_argument(
|
| 565 |
+
"--experimental.pipeline_parallel_schedule_csv",
|
| 566 |
+
type=str,
|
| 567 |
+
default="",
|
| 568 |
+
help="""
|
| 569 |
+
Specify the path to the pipeline parallel schedule csv file to use.
|
| 570 |
+
The pipeline_parallel_schedule argument must be either
|
| 571 |
+
PipelineScheduleSingle, PipelineScheduleMulti, or _PipelineScheduleRuntime.
|
| 572 |
+
""",
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
self.parser.add_argument(
|
| 576 |
+
"--experimental.pipeline_parallel_microbatches",
|
| 577 |
+
type=int,
|
| 578 |
+
default=None,
|
| 579 |
+
help="""
|
| 580 |
+
How many microbatches to split the global training batch into when using pipeline parallelism.
|
| 581 |
+
|
| 582 |
+
The global training batch size must be evenly divisible by the number of microbatches.
|
| 583 |
+
|
| 584 |
+
The default value will be the number of pipeline stages, if unspecified.
|
| 585 |
+
""",
|
| 586 |
+
)
|
| 587 |
+
self.parser.add_argument(
|
| 588 |
+
"--experimental.enable_compiled_autograd",
|
| 589 |
+
action="store_true",
|
| 590 |
+
help="Enable CompiledAutograd to compile the backward.",
|
| 591 |
+
)
|
| 592 |
+
self.parser.add_argument(
|
| 593 |
+
"--experimental.context_parallel_degree",
|
| 594 |
+
type=int,
|
| 595 |
+
default=1,
|
| 596 |
+
help="Context parallelism degree. 1 means disabled.",
|
| 597 |
+
)
|
| 598 |
+
self.parser.add_argument(
|
| 599 |
+
"--experimental.context_parallel_rotate_method",
|
| 600 |
+
type=str,
|
| 601 |
+
default="allgather",
|
| 602 |
+
help="""
|
| 603 |
+
The collective to use in context parallel SDPA for kv shards exchange.
|
| 604 |
+
|
| 605 |
+
'allgather' means to all-gather all kv shards on ranks after the first sub-SDPA computation,
|
| 606 |
+
|
| 607 |
+
'alltoall' means to all-to-all shuffle the kv shards.
|
| 608 |
+
|
| 609 |
+
The default value is 'allgather'.
|
| 610 |
+
""",
|
| 611 |
+
)
|
| 612 |
+
# I'm not particularly fond of this. Users can choose to write their own wrapper
|
| 613 |
+
# module and import TorchTitan training loop and execute it, which look cleaner.
|
| 614 |
+
# One reason to provide this option is to allow users to use the existing run script.
|
| 615 |
+
# While the script is pretty trivial now, we may add more logic when integrating
|
| 616 |
+
# with TorchFT.
|
| 617 |
+
# This option is subject to change and may be deleted in the future.
|
| 618 |
+
self.parser.add_argument(
|
| 619 |
+
"--experimental.custom_model_path",
|
| 620 |
+
type=str,
|
| 621 |
+
default="",
|
| 622 |
+
help="""
|
| 623 |
+
The --custom_model_path option allows to specify a custom path to a model module
|
| 624 |
+
that is not natively implemented within TorchTitan.
|
| 625 |
+
Acceptable values are the file system path to the module (e.g., my_models/model_x)
|
| 626 |
+
dotted import module (e.g., some_package.model_x).
|
| 627 |
+
""",
|
| 628 |
+
)
|
| 629 |
+
# checkpointing configs
|
| 630 |
+
self.parser.add_argument(
|
| 631 |
+
"--checkpoint.enable_checkpoint",
|
| 632 |
+
action="store_true",
|
| 633 |
+
help="Whether to enable checkpoint",
|
| 634 |
+
)
|
| 635 |
+
self.parser.add_argument(
|
| 636 |
+
"--checkpoint.folder",
|
| 637 |
+
type=str,
|
| 638 |
+
default="checkpoint",
|
| 639 |
+
help="""
|
| 640 |
+
The folder to store the checkpoints.
|
| 641 |
+
When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
|
| 642 |
+
""",
|
| 643 |
+
)
|
| 644 |
+
self.parser.add_argument(
|
| 645 |
+
"--checkpoint.initial_load_path", type=str, default=None,
|
| 646 |
+
help="""
|
| 647 |
+
This option specifies the path to the initial checkpoint to load, which is
|
| 648 |
+
particularly useful for resuming training from a previous run with a
|
| 649 |
+
different output path or when loading a checkpoint from a pre-trained model.
|
| 650 |
+
If the checkpoint folder for the current run is not empty,
|
| 651 |
+
located at {--job.dump_folder}/{--checkpoint.folder}, this option will be ignored.
|
| 652 |
+
This feature allows users to load an initial checkpoint from a different folder and
|
| 653 |
+
continue training, saving new checkpoints to the specified folder without affecting
|
| 654 |
+
the existing ones.
|
| 655 |
+
|
| 656 |
+
Note that the path should contain the full path to the checkpoint folder,
|
| 657 |
+
including the step number, if any; for example,
|
| 658 |
+
"//pre_train/checkpoints/llama3/llama3_8b/step_10000".
|
| 659 |
+
"""
|
| 660 |
+
)
|
| 661 |
+
self.parser.add_argument(
|
| 662 |
+
"--checkpoint.initial_load_model_weights_only",
|
| 663 |
+
dest='checkpoint.initial_load_model_weights_only', action="store_true", default=True,
|
| 664 |
+
help="""
|
| 665 |
+
This option specifies if only the model weights should be loaded during the initial
|
| 666 |
+
checkpoint load. The option is only used when `initial_load_path` is specified, and
|
| 667 |
+
only applies to a model_weights_only checkpoint. Loading a periodic checkpoint
|
| 668 |
+
may lead to unexpected behavior if this option is set to True.
|
| 669 |
+
If False, the checkpoint at `initial_load_path` is treated as a standard training
|
| 670 |
+
checkpoint, including optimizer and training states.
|
| 671 |
+
The default setting for this option is True. Note that you will have to use
|
| 672 |
+
`--checkpoint.no_initial_load_model_weights_only` to override the default setting.
|
| 673 |
+
"""
|
| 674 |
+
)
|
| 675 |
+
self.parser.add_argument(
|
| 676 |
+
"--checkpoint.no_initial_load_model_weights_only",
|
| 677 |
+
dest='checkpoint.initial_load_model_weights_only', action="store_false",
|
| 678 |
+
)
|
| 679 |
+
self.parser.add_argument(
|
| 680 |
+
"--checkpoint.interval",
|
| 681 |
+
type=int,
|
| 682 |
+
default=500,
|
| 683 |
+
help="Checkpointing interval in steps.",
|
| 684 |
+
)
|
| 685 |
+
self.parser.add_argument(
|
| 686 |
+
"--checkpoint.last_save_model_weights_only",
|
| 687 |
+
action="store_true",
|
| 688 |
+
help="""
|
| 689 |
+
When last_save_model_weights_only=True, only model weights will be saved at the end of training,
|
| 690 |
+
the last save. With this, checkpoints can be loaded using `torch.load(..., weights_only=True)`
|
| 691 |
+
after conversion. When last_save_model_weights_only=False, the full checkpoint will be saved.
|
| 692 |
+
A full checkpoint includes model, optimizer and train_state, which can be used to resume training.
|
| 693 |
+
The default value is false.
|
| 694 |
+
""",
|
| 695 |
+
)
|
| 696 |
+
self.parser.add_argument(
|
| 697 |
+
"--checkpoint.export_dtype",
|
| 698 |
+
type=str,
|
| 699 |
+
default="float32",
|
| 700 |
+
choices=["float16", "bfloat16", "float32"],
|
| 701 |
+
help="""
|
| 702 |
+
Converts to the specified precision when training completes and model_weights_only=true.
|
| 703 |
+
Currently supports float32, float16, and bfloat16.
|
| 704 |
+
The default value is float32.
|
| 705 |
+
""",
|
| 706 |
+
)
|
| 707 |
+
self.parser.add_argument(
|
| 708 |
+
"--checkpoint.create_seed_checkpoint",
|
| 709 |
+
action="store_true",
|
| 710 |
+
help="""
|
| 711 |
+
Initializes the full model without applying parallelisms, and then saves it as a seed checkpoint.
|
| 712 |
+
Note: requires user to call train.py without specifying any parallelisms, e.g. NGPU=1.
|
| 713 |
+
Could be implemented as a separate script, but this way shares more code.
|
| 714 |
+
""",
|
| 715 |
+
)
|
| 716 |
+
self.parser.add_argument(
|
| 717 |
+
"--checkpoint.async_mode",
|
| 718 |
+
type=str,
|
| 719 |
+
default="disabled",
|
| 720 |
+
help="""
|
| 721 |
+
Which async checkpoint mode to use. Currently there are 3 different modes.
|
| 722 |
+
1. "disabled": synchronized checkpointing will be used.
|
| 723 |
+
2. "async": torch.distributed.checkpoint.async_save will be used.
|
| 724 |
+
3. "async_with_pinned_mem": this option utilizes a dedicated pinned memory
|
| 725 |
+
space and creates a separate process for faster GPU->CPU transfer
|
| 726 |
+
performance and eliminating GIL contention. The cost is increased CPU
|
| 727 |
+
memory usage. If insufficient CPU memory is available, performance may
|
| 728 |
+
degrade due to memory paging. For most users, "async" should suffice as
|
| 729 |
+
the performance overhead is typically small (on the order of tens of
|
| 730 |
+
seconds) compared to checkpointing frequency. This mode can be employed
|
| 731 |
+
to pursue near-zero checkpointing times (e.g., < 1 second) given
|
| 732 |
+
appropriate hardware support such as ample CPU memory and fast PCIe.
|
| 733 |
+
|
| 734 |
+
"disabled" is the default mode.
|
| 735 |
+
""",
|
| 736 |
+
)
|
| 737 |
+
self.parser.add_argument(
|
| 738 |
+
"--checkpoint.keep_latest_k",
|
| 739 |
+
type=int,
|
| 740 |
+
default=0,
|
| 741 |
+
help="""
|
| 742 |
+
Keeps only the latest k checkpoints, and purging older ones. If 0, keep all checkpoints.
|
| 743 |
+
0 is the default value. k cannot be 1 as the last one may be in the process of being
|
| 744 |
+
saved. As a result, the metadata of the last one may not be ready yet.
|
| 745 |
+
""",
|
| 746 |
+
)
|
| 747 |
+
self.parser.add_argument(
|
| 748 |
+
"--checkpoint.load_step",
|
| 749 |
+
type=int,
|
| 750 |
+
default=-1,
|
| 751 |
+
help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.",
|
| 752 |
+
)
|
| 753 |
+
self.parser.add_argument(
|
| 754 |
+
"--checkpoint.exclude_from_loading",
|
| 755 |
+
type=string_list,
|
| 756 |
+
nargs="*",
|
| 757 |
+
default=[],
|
| 758 |
+
help="""
|
| 759 |
+
Exclude specific keys from being loaded from the checkpoint.
|
| 760 |
+
Provide a comma-separated list of keys to exclude, e.g. 'optimizer,lr_scheduler,dataloader'.
|
| 761 |
+
This will load the model only, excluding the specified keys.
|
| 762 |
+
""",
|
| 763 |
+
)
|
| 764 |
+
# activation checkpointing configs
|
| 765 |
+
self.parser.add_argument(
|
| 766 |
+
"--activation_checkpoint.mode",
|
| 767 |
+
type=str,
|
| 768 |
+
default="selective",
|
| 769 |
+
help="Type of activation checkpointing to use ['none', 'full', 'selective']",
|
| 770 |
+
)
|
| 771 |
+
self.parser.add_argument(
|
| 772 |
+
"--activation_checkpoint.selective_ac_option",
|
| 773 |
+
type=str,
|
| 774 |
+
default="2", # 2 = checkpoint every other layer
|
| 775 |
+
help="""
|
| 776 |
+
Selective activation checkpointing options ['int', 'op'].
|
| 777 |
+
'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
|
| 778 |
+
""",
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
self.parser.add_argument(
|
| 782 |
+
"--activation_offload.mode",
|
| 783 |
+
type=str,
|
| 784 |
+
default="none",
|
| 785 |
+
help="""
|
| 786 |
+
if we are using activation offload or not. Options are ['none', 'full'].
|
| 787 |
+
""",
|
| 788 |
+
)
|
| 789 |
+
|
| 790 |
+
# float8 configs
|
| 791 |
+
self.parser.add_argument(
|
| 792 |
+
"--float8.enable_fsdp_float8_all_gather",
|
| 793 |
+
action="store_true",
|
| 794 |
+
help="Whether enable float8 all-gather in FSDP, recommended for tensorwise scaling",
|
| 795 |
+
)
|
| 796 |
+
self.parser.add_argument(
|
| 797 |
+
"--float8.precompute_float8_dynamic_scale_for_fsdp",
|
| 798 |
+
action="store_true",
|
| 799 |
+
help="Whether precompute float8 scales dynamically for FSDP, recommended for tensorwise scaling",
|
| 800 |
+
)
|
| 801 |
+
self.parser.add_argument(
|
| 802 |
+
"--float8.force_recompute_fp8_weight_in_bwd",
|
| 803 |
+
action="store_true",
|
| 804 |
+
help="""
|
| 805 |
+
Whether to force the recomputation of FP8 weights during backward pass.
|
| 806 |
+
When using FSDP with tensorwise scaling, it is recommended to enable
|
| 807 |
+
`force_recompute_fp8_weight_in_bwd` to prevent saving unsharded FP8 weights
|
| 808 |
+
for backward computation.
|
| 809 |
+
""",
|
| 810 |
+
)
|
| 811 |
+
self.parser.add_argument(
|
| 812 |
+
"--float8.recipe_name",
|
| 813 |
+
type=str,
|
| 814 |
+
default=None,
|
| 815 |
+
choices=["tensorwise", "rowwise", "rowwise_with_gw_hp"],
|
| 816 |
+
help="""
|
| 817 |
+
If specified, creates float8 config from recipe name, valid choices are
|
| 818 |
+
`tensorwise`, `rowwise` and `rowwise_with_gw_hp`.
|
| 819 |
+
""",
|
| 820 |
+
)
|
| 821 |
+
|
| 822 |
+
# communications library settings
|
| 823 |
+
self.parser.add_argument(
|
| 824 |
+
"--comm.init_timeout_seconds",
|
| 825 |
+
type=int,
|
| 826 |
+
default=300,
|
| 827 |
+
help="Timeout for communication operations, during initialization and first train step.",
|
| 828 |
+
)
|
| 829 |
+
self.parser.add_argument(
|
| 830 |
+
"--comm.train_timeout_seconds",
|
| 831 |
+
type=int,
|
| 832 |
+
default=100,
|
| 833 |
+
help=(
|
| 834 |
+
"Timeout for communication operations after the first train step -- "
|
| 835 |
+
"usually a tighter bound than during initialization."
|
| 836 |
+
),
|
| 837 |
+
)
|
| 838 |
+
self.parser.add_argument(
|
| 839 |
+
"--comm.trace_buf_size",
|
| 840 |
+
type=int,
|
| 841 |
+
default=20000,
|
| 842 |
+
help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
|
| 843 |
+
)
|
| 844 |
+
|
| 845 |
+
# memory estimation settings
|
| 846 |
+
self.parser.add_argument(
|
| 847 |
+
"--memory_estimation.enabled",
|
| 848 |
+
help="Whether to estimate memory usage for FSDP",
|
| 849 |
+
action="store_true",
|
| 850 |
+
)
|
| 851 |
+
|
| 852 |
+
self.parser.add_argument(
|
| 853 |
+
"--memory_estimation.disable_fake_mode",
|
| 854 |
+
help="Whether to estimate memory under FakeTensorMode",
|
| 855 |
+
action="store_true",
|
| 856 |
+
)
|
| 857 |
+
|
| 858 |
+
self.parser.add_argument(
|
| 859 |
+
"--fault_tolerance.enable",
|
| 860 |
+
action="store_true",
|
| 861 |
+
help="""
|
| 862 |
+
Enable TorchFT integration. When TorchFT is enabled, HSDP will be used.
|
| 863 |
+
And --fault_tolerance.data_parallel_replicate_degree should be 1 and
|
| 864 |
+
--fault_tolerance.group_size will be used to control the maximum
|
| 865 |
+
replicate group size as the replicate group size is dynamic.
|
| 866 |
+
|
| 867 |
+
Note that this is still an experimental feature.
|
| 868 |
+
""",
|
| 869 |
+
)
|
| 870 |
+
|
| 871 |
+
self.parser.add_argument(
|
| 872 |
+
"--fault_tolerance.replica_id",
|
| 873 |
+
type=int,
|
| 874 |
+
default=0,
|
| 875 |
+
help="The TorchFT replica ID of this run.",
|
| 876 |
+
)
|
| 877 |
+
|
| 878 |
+
self.parser.add_argument(
|
| 879 |
+
"--fault_tolerance.group_size",
|
| 880 |
+
type=int,
|
| 881 |
+
default=0,
|
| 882 |
+
help="""
|
| 883 |
+
The number of TorchFT replicate groups. This number will be used for
|
| 884 |
+
dataloader to split the dataset across the replicate groups and FSDP
|
| 885 |
+
dimension
|
| 886 |
+
""",
|
| 887 |
+
)
|
| 888 |
+
|
| 889 |
+
self.parser.add_argument(
|
| 890 |
+
"--fault_tolerance.min_replica_size",
|
| 891 |
+
type=int,
|
| 892 |
+
default=1,
|
| 893 |
+
help="The minimum number of FT replica for each step.",
|
| 894 |
+
)
|
| 895 |
+
|
| 896 |
+
def to_dict(self):
|
| 897 |
+
return self.args_dict
|
| 898 |
+
|
| 899 |
+
def parse_args(self, args_list: list = sys.argv[1:]):
|
| 900 |
+
args, cmd_args = self.parse_args_from_command_line(args_list)
|
| 901 |
+
config_file = getattr(args, "job.config_file", None)
|
| 902 |
+
# build up a two level dict
|
| 903 |
+
args_dict = self._args_to_two_level_dict(args)
|
| 904 |
+
if config_file is not None:
|
| 905 |
+
try:
|
| 906 |
+
with open(config_file, "rb") as f:
|
| 907 |
+
for k, v in tomllib.load(f).items():
|
| 908 |
+
# to prevent overwrite of non-specified keys
|
| 909 |
+
args_dict[k] |= v
|
| 910 |
+
except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
|
| 911 |
+
logger.exception(
|
| 912 |
+
f"Error while loading the configuration file: {config_file}"
|
| 913 |
+
)
|
| 914 |
+
logger.exception(f"Error details: {str(e)}")
|
| 915 |
+
raise e
|
| 916 |
+
|
| 917 |
+
# Checking string-list arguments are properly split into a list
|
| 918 |
+
# if split-points came from 'args' (from cmd line) it would have already been parsed into a list by that parser
|
| 919 |
+
string_list_argnames = self._get_string_list_argument_names()
|
| 920 |
+
for n in string_list_argnames:
|
| 921 |
+
check_string_list_argument(args_dict, n)
|
| 922 |
+
|
| 923 |
+
# override args dict with cmd_args
|
| 924 |
+
cmd_args_dict = self._args_to_two_level_dict(cmd_args)
|
| 925 |
+
for section, section_args in cmd_args_dict.items():
|
| 926 |
+
for k, v in section_args.items():
|
| 927 |
+
args_dict[section][k] = v
|
| 928 |
+
|
| 929 |
+
self.args_dict = args_dict
|
| 930 |
+
|
| 931 |
+
for k, v in args_dict.items():
|
| 932 |
+
class_type = type(k.title(), (), v)
|
| 933 |
+
setattr(self, k, class_type())
|
| 934 |
+
self._validate_config()
|
| 935 |
+
|
| 936 |
+
def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
|
| 937 |
+
args_dict = defaultdict(defaultdict)
|
| 938 |
+
for k, v in vars(args).items():
|
| 939 |
+
first_level_key, second_level_key = k.split(".", 1)
|
| 940 |
+
args_dict[first_level_key][second_level_key] = v
|
| 941 |
+
return args_dict
|
| 942 |
+
|
| 943 |
+
def _validate_config(self) -> None:
|
| 944 |
+
# TODO: Add more mandatory validations
|
| 945 |
+
assert self.model.config
|
| 946 |
+
assert self.model.tokenizer_path
|
| 947 |
+
|
| 948 |
+
def _get_string_list_argument_names(self) -> list[str]:
|
| 949 |
+
"""Get the parser argument names of type `string_list`."""
|
| 950 |
+
string_list_args = [
|
| 951 |
+
v.dest for v in self.parser._actions if v.type is string_list
|
| 952 |
+
]
|
| 953 |
+
return string_list_args
|
| 954 |
+
|
| 955 |
+
def parse_args_from_command_line(
|
| 956 |
+
self, args_list
|
| 957 |
+
) -> Tuple[argparse.Namespace, argparse.Namespace]:
|
| 958 |
+
"""
|
| 959 |
+
Parse command line arguments and return the parsed args and the command line only args
|
| 960 |
+
"""
|
| 961 |
+
args = self.parser.parse_args(args_list)
|
| 962 |
+
string_list_argnames = set(self._get_string_list_argument_names())
|
| 963 |
+
|
| 964 |
+
# aux parser to parse the command line only args, with no defaults from main parser
|
| 965 |
+
aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
|
| 966 |
+
for arg, val in vars(args).items():
|
| 967 |
+
if isinstance(val, bool):
|
| 968 |
+
aux_parser.add_argument(
|
| 969 |
+
"--" + arg, action="store_true" if val else "store_false"
|
| 970 |
+
)
|
| 971 |
+
elif arg in string_list_argnames:
|
| 972 |
+
# without this special case, type inference breaks here,
|
| 973 |
+
# since the inferred type is just 'list' and it ends up flattening
|
| 974 |
+
# e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
|
| 975 |
+
aux_parser.add_argument("--" + arg, type=string_list)
|
| 976 |
+
else:
|
| 977 |
+
aux_parser.add_argument("--" + arg, type=type(val))
|
| 978 |
+
|
| 979 |
+
cmd_args, _ = aux_parser.parse_known_args(args_list)
|
| 980 |
+
|
| 981 |
+
return args, cmd_args
|