diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5aa1c73c643affd65fdd3f58623db853fc84d16e --- /dev/null +++ b/README.md @@ -0,0 +1,58 @@ +--- +base_model: Qwen/Qwen2.5-1.5B-Instruct +library_name: transformers +model_name: Qwen2.5-1.5B-Open-R1-GRPO +tags: +- generated_from_trainer +- trl +- sft +licence: license +--- + +# Model Card for Qwen2.5-1.5B-Open-R1-GRPO + +This model is a fine-tuned version of [Qwen/Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct). +It has been trained using [TRL](https://github.com/huggingface/trl). + +## Quick start + +```python +from transformers import pipeline + +question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?" +generator = pipeline("text-generation", model="ItsMaxNorm/Qwen2.5-1.5B-Open-R1-GRPO", device="cuda") +output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0] +print(output["generated_text"]) +``` + +## Training procedure + + + + +This model was trained with SFT. + +### Framework versions + +- TRL: 0.18.0 +- Transformers: 4.52.3 +- Pytorch: 2.6.0 +- Datasets: 3.6.0 +- Tokenizers: 0.21.1 + +## Citations + + + +Cite TRL as: + +```bibtex +@misc{vonwerra2022trl, + title = {{TRL: Transformer Reinforcement Learning}}, + author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec}, + year = 2020, + journal = {GitHub repository}, + publisher = {GitHub}, + howpublished = {\url{https://github.com/huggingface/trl}} +} +``` \ No newline at end of file diff --git a/all_results.json b/all_results.json new file mode 100644 index 0000000000000000000000000000000000000000..2586025d8c1a0779c7d44e4efee6daa055342286 --- /dev/null +++ b/all_results.json @@ -0,0 +1,8 @@ +{ + "total_flos": 83412022984704.0, + "train_loss": 0.6596692787564319, + "train_runtime": 4532.2478, + "train_samples": 93733, + "train_samples_per_second": 20.681, + "train_steps_per_second": 0.041 +} \ No newline at end of file diff --git a/generation_config.json b/generation_config.json new file mode 100644 index 0000000000000000000000000000000000000000..0adec16ef11115eb28822a475cb567cba49ae18a --- /dev/null +++ b/generation_config.json @@ -0,0 +1,11 @@ +{ + "bos_token_id": 151643, + "do_sample": true, + "eos_token_id": 151645, + "pad_token_id": 151643, + "repetition_penalty": 1.1, + "temperature": 0.7, + "top_k": 20, + "top_p": 0.8, + "transformers_version": "4.52.3" +} diff --git a/global_step183/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..5e0be9380c80ab5b4093701286bd7556b03aecba --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3136915d265417e86d5f3a53479794b78e1611ef75150b9be757ccd72f9da3d9 +size 578897648 diff --git a/global_step183/bf16_zero_pp_rank_10_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_10_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..0932c5e73b142626370973b6b38d347b18702ac5 --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_10_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ac4208b1fd413ad21a234f9364e934e3e566e2797a312c88bad7f0fbd252f879 +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_11_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_11_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..20aead5b920a514a89e0dc99b44c87fc5e09ce7a --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_11_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3df41dd42bf5269814d08d6e2262d766fc56e8cfcf7a770d84ad731edd0a56a8 +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_12_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_12_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..ede93b2a963d65a950d4cdf786ddf614de86b977 --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_12_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f81ba2b0dc32ca86d0d8966a02126d18cf380a4cda06cfd5b49f3435a77ab414 +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_13_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_13_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..3c88c7b567a54a0ac9b1d0dabdfbb0ba6145f39c --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_13_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e80f23fc0febeee6a133036ccf84c7f2037bfbc1865aedf3b3b142594c7f2d3c +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_14_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_14_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..2b2bbe4108fe0f4c23222b0e01e27fb7b8868154 --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_14_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ff27eb6d57e1de875144573c035f12d64e2f0fb8240e4bf1062c129c6aae0f2 +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_15_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_15_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..bf00be1a059dca48d5cdffc4491fbab0c2a9c65c --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_15_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2846765a651335d691df088db4983ad2290668b23979b881c89607c445b0b9ec +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_16_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_16_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..584cb5d92b73d3625887bbb9e422d598afdeb4e9 --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_16_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:74e5a8a7ed7017c3575b2a14ffe0cfc69acfda1f39d3da52e67263a375cd8c23 +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_17_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_17_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..056854e0c7035d95165c3fef35d4e4da1ef585e4 --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_17_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a0b16bdad2aceb62d832a915e51155cc51dba490593e2096bac6b33ec3ac992 +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_18_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_18_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..cac4c7c2b6ae1d56e72d7f9daa4e7fba3137e4de --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_18_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b54582ab3fc720d5f9e197b418915938be1d94a5bfd59fbcbee3e1cff76afb3b +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_19_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_19_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..bcf52014f43667374e810b8f6a73f81a9497494f --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_19_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9aa3d67fe87523645f22fc906f45025d29349994bfd18910b61b641a20d844f0 +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..6398eec3814ba6d246ba7900e0092ad2f04b5103 --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6e5a3188fd584e0a784ccb95bd02afe975707cc440f86358838477119b444785 +size 578897648 diff --git a/global_step183/bf16_zero_pp_rank_20_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_20_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..7020b89d3380aee6b389c8549b33450e6e82cbc4 --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_20_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4991f971045b154303336850f3bd0c69f8bf3022e5fc9b4dda1dbab2189b0f94 +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_21_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_21_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..4072ee3400d301f693b278f3e0c8fe47482b15d5 --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_21_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c69aed0cfdfa77c20c5f80a31237212c75896da7c42aa4ade2fb56420b90276c +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_22_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_22_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..71fe9a9d86b3a49e7c3f5278bc2eeff24f3c2308 --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_22_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:081751f7588a2ee7d27ac676dfb646f49765c45f48c0a127f6bbcf3aef4b8030 +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_23_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_23_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..b57af241ef0f956bf00224f09362dec9a19be14a --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_23_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5a6ffaab4764ead334625f477deb7ccde46ec140849d9e1b7d05cef8a5efa733 +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_24_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_24_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..8cfbf8a90dea86da275c1d4aee7d32c996ee181f --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_24_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b95f20a5def82b80476b7d8482ab7a0d889c9e1f691c34003084670087dc6a86 +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_25_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_25_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..6343aef4b10021d9c496f57359e69e87b9667d2c --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_25_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d6b712bb4678c5387fda9e2bc1a2851bf9707e061872d841dc147e166eebc610 +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_26_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_26_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..7850669eda17c3f041fb0ed71f626d31e3da0dff --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_26_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79fc7f526f10bbfc2205b8c4cb9f7b9a72638ab24cfe68fb4b234ade9ff8b9e6 +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_27_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_27_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..dbcee1f29e5c26329f3ed64004268bf658313bad --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_27_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0d70a0bf11a120f5ddb79c0a053bf508d5df8131309b3c048c0af70711252c7d +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_28_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_28_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..486a21d06efc6ba3b513c5284848bd122833b8d5 --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_28_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0f68a1f8f9f8baa7438c06a2c4c1a99abad08f46d4a9a234f1ffcb26964c4653 +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_29_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_29_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..c50eaf330917a4010e867d8da378c3a3d57c5fb9 --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_29_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c1ea9e8faf3400988f61b30e835d9b8ae577be2296a54fde487428d57f23e0f0 +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..5063d5fe606d33f37432990747047872efc3bcd5 --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83261b7459bf5cb0a5269dba502c5dc496ef78893931749c2edd9f3278bccbae +size 578897648 diff --git a/global_step183/bf16_zero_pp_rank_30_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_30_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..ba74582b4666289f11ec97511041d6778da5eba6 --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_30_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d8fd934a8a07476de918eb01b8d78f278f1a052cdc9668638aaead080dad5cea +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_31_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_31_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..d86cb62e2444de18fb2fd30cd837d6a2067615bf --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_31_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cd162bcdc4b2f7ce575db9b387650f4f3d86a946fe93f2085f3d08aa75ab5a07 +size 578897660 diff --git a/global_step183/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..7a5d852e006d18b1b755f480c1bfbeb36c61ddee --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f9057f5ef23383e7ef3c0b1b333d59281b37e7191b142d53b4b1f2d692a3502a +size 578897648 diff --git a/global_step183/bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..5f5702e7d8af464bd5ccd62c4db279fb3a88842e --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6b0d99308ad831ff64bd98c098ad50a554765564630e5dc7db1615ca9fb04063 +size 578897648 diff --git a/global_step183/bf16_zero_pp_rank_5_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_5_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..2cbca49b672ed636b2f09471d4cec88b5e2990e0 --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_5_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:009a81d85e71ceeef03e45e30026616317071d0a40faf7710defb7df7d5e6bef +size 578897648 diff --git a/global_step183/bf16_zero_pp_rank_6_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_6_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..ffd3e10b52c43ff1c9dab2e72305498b8fe25985 --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_6_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:47fa3ed6c92f797c7afc5f17076e6e8120fd6a4251ec166a6269fd8a1d2ff946 +size 578897648 diff --git a/global_step183/bf16_zero_pp_rank_7_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_7_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..1f55c5a1f8dc5371146f1f681e26bbe622e63994 --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_7_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:19e168ec2e9dac12e7f7f27d2fb51d922d1069730e7181980482dbf83745c97e +size 578897648 diff --git a/global_step183/bf16_zero_pp_rank_8_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_8_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..c7207f4d75994c7578aac556249a1e3c7fb0e77d --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_8_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a6e219d7151349efd1350e308916b8cb78e686b0351e3fd5f5826297dbba93f0 +size 578897648 diff --git a/global_step183/bf16_zero_pp_rank_9_mp_rank_00_optim_states.pt b/global_step183/bf16_zero_pp_rank_9_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..53cb751d396e370729e677d7cebdedc6629889ea --- /dev/null +++ b/global_step183/bf16_zero_pp_rank_9_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c4a951d9fea7c80faa847a7b2f5fbbf43a5dfbd5791e19cb114a2e5083252e6 +size 578897648 diff --git a/global_step183/zero_pp_rank_0_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_0_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..ca2a8dcb5f57e9bb2c94fdcc2a5da35257f7cded --- /dev/null +++ b/global_step183/zero_pp_rank_0_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2a910410c9aca0c5880b0d78172e3abef0ea00af9404da1e132173e7c8833337 +size 166072 diff --git a/global_step183/zero_pp_rank_10_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_10_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..6e02d3dbbf97c719dc7bc08a631c39a781c2d873 --- /dev/null +++ b/global_step183/zero_pp_rank_10_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb9e30f10d95a4d3aee9f84c59d3bcdda419f29ffbe756e1fc11183e9131cc9e +size 166350 diff --git a/global_step183/zero_pp_rank_11_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_11_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..73485634d0b3e1dee79380d8bab15fb6b24c934b --- /dev/null +++ b/global_step183/zero_pp_rank_11_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d51ce6b8792cf134e8af9aadd9deda4fcdddff37b824854af5bef2c16841e9b1 +size 166350 diff --git a/global_step183/zero_pp_rank_12_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_12_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..2f015cdc839939697634bef8c5fd43e33af206b9 --- /dev/null +++ b/global_step183/zero_pp_rank_12_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2bb1d0358e7bff3f8fd40da973c49377c778efb410c33cbc6740a9375f0e6832 +size 166350 diff --git a/global_step183/zero_pp_rank_13_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_13_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..35d46ce3acbfca8b68b5ebecc98dd4670caa02c1 --- /dev/null +++ b/global_step183/zero_pp_rank_13_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a80a6e5052e3c7cea9708c46eb407cd1150e584ad87d7ac793288a24da403e8c +size 166350 diff --git a/global_step183/zero_pp_rank_14_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_14_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..b472a84803b3a4cb4c489000c367c0c6f960b99e --- /dev/null +++ b/global_step183/zero_pp_rank_14_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3db23914c4d585f7e86179964719294dce335ec92232956a7f3f318a761c6aa7 +size 166350 diff --git a/global_step183/zero_pp_rank_15_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_15_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..78b128b6acd60e43870765e8d513e2a68f0373b3 --- /dev/null +++ b/global_step183/zero_pp_rank_15_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2854c3a15706e6e635ef18163a1e61885f4901516ea3b23298030323bfa3ca7a +size 166350 diff --git a/global_step183/zero_pp_rank_16_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_16_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..b02da93fe525f1e490912318fc113aaa6b0c70cd --- /dev/null +++ b/global_step183/zero_pp_rank_16_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4df1b861930d4f1ec7af220cf7f2d17b6cde202748dc7ba3aebc53ceb104ae2d +size 166350 diff --git a/global_step183/zero_pp_rank_17_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_17_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..319e66b2b62edb07cf53cfa22353cb82c9521e64 --- /dev/null +++ b/global_step183/zero_pp_rank_17_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df055f3212e0218a0bf0c2485b663def38f23c72ab96641af1ceeb71fa52c0b5 +size 166350 diff --git a/global_step183/zero_pp_rank_18_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_18_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..2748c89b5a7e4f8c01f4946b8059ca1aa874a91d --- /dev/null +++ b/global_step183/zero_pp_rank_18_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b193626e45fed30b6e5f2730032aaab88d23cd1dde24bba608631e48236a4bf0 +size 166350 diff --git a/global_step183/zero_pp_rank_19_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_19_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..e305eff7b2b5db78d54255a416e079e129ec9163 --- /dev/null +++ b/global_step183/zero_pp_rank_19_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2149c88768003b007c3a9f8e386df97db64329b3de7c07db2078c724788cf4a4 +size 166350 diff --git a/global_step183/zero_pp_rank_1_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_1_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..a6246f840bb0590f9cff7262b7896273cf82db0b --- /dev/null +++ b/global_step183/zero_pp_rank_1_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f64a010320812fd70544c6781bd885cdf57785c87492ee410d92e7860321d3e +size 166008 diff --git a/global_step183/zero_pp_rank_20_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_20_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..b18d0ff7fea96d1fac8350e0ed78bac8f101fb2c --- /dev/null +++ b/global_step183/zero_pp_rank_20_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1200e3e3c247d8497b1d5e1ffa5ad79a71b350c13552d83764ed1503a45d605c +size 166350 diff --git a/global_step183/zero_pp_rank_21_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_21_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..98d916762661e63c50d1634da353110ba118079f --- /dev/null +++ b/global_step183/zero_pp_rank_21_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1782a91dc05abdb7186321180adb541e32d9af0616220200107e66a2c7e86a40 +size 166350 diff --git a/global_step183/zero_pp_rank_22_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_22_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..51c26c79d258b69d9fc25e3c1a235292bff8aeee --- /dev/null +++ b/global_step183/zero_pp_rank_22_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79d7e7618a8ff694f2f43801c09ae0f150a794d98f065b7789fa1023cbf46416 +size 166350 diff --git a/global_step183/zero_pp_rank_23_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_23_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..1fbfee9252fd5c2c87574c7091c6c1486a9e7b0a --- /dev/null +++ b/global_step183/zero_pp_rank_23_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d7a30d777789b3d05da70581c0668f43f586045a2c12e7287c566710378d0898 +size 166350 diff --git a/global_step183/zero_pp_rank_24_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_24_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..8a1911749f9d37b00c1748fb9daab50773125e95 --- /dev/null +++ b/global_step183/zero_pp_rank_24_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9ff4ab8d4f7f51189c294ec1be8a18b3388923dbd63675d20f5d2ecb5ccddbdb +size 166350 diff --git a/global_step183/zero_pp_rank_25_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_25_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..a041c3b2b76c67cfd85c5fc7c00908ed22611dae --- /dev/null +++ b/global_step183/zero_pp_rank_25_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eded50d611ecc6c7dee748c9bb8305d8808f2390e772f58dac62c25202cf37d8 +size 166350 diff --git a/global_step183/zero_pp_rank_26_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_26_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..d8c4c3e1e45dd43ed4dda0dca50f871fee499efa --- /dev/null +++ b/global_step183/zero_pp_rank_26_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1e5e08f0c9c733aaeb06be61a8fa4b8105b7c23233ec20e5e4b09b09dbf76101 +size 166350 diff --git a/global_step183/zero_pp_rank_27_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_27_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..a56f151c9c808a7d1e3413e1e5c173435a57852b --- /dev/null +++ b/global_step183/zero_pp_rank_27_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4075b475a177114a57cfb538a79267ed96544b5eb4b1be660b71f610611d6e4b +size 166350 diff --git a/global_step183/zero_pp_rank_28_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_28_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..ba1ee23b38f377fa9028517b7886e8b205d1c5f9 --- /dev/null +++ b/global_step183/zero_pp_rank_28_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a119557e7e70b20e4370add676b753ee23b7c9c6c2da26fee20f21b1d6c440c +size 166350 diff --git a/global_step183/zero_pp_rank_29_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_29_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..7b4bd65c7002cb006513893443812bc17287779d --- /dev/null +++ b/global_step183/zero_pp_rank_29_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:860a4c6c57569d5106c6e3044112f3abd912291168f4a6ca15912eb9e7b398e4 +size 166350 diff --git a/global_step183/zero_pp_rank_2_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_2_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..2fc92360a6a4f4fb0bf879d01c0c309dc9320627 --- /dev/null +++ b/global_step183/zero_pp_rank_2_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dc1b0f5f36682f95898c9888514b40f14b281b7b3be7535625bd26261ce5f6c7 +size 166008 diff --git a/global_step183/zero_pp_rank_30_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_30_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..eb509294792c7d0cdfd4f827b6646a994c53e7a8 --- /dev/null +++ b/global_step183/zero_pp_rank_30_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d2667804a08345ee02a7a44fd8e9f8b4012e11c9c693a9d8097f0aae93678319 +size 166350 diff --git a/global_step183/zero_pp_rank_31_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_31_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..18fa84d3b261f7a596ef5bab51cdeb56ece85dd2 --- /dev/null +++ b/global_step183/zero_pp_rank_31_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5fdecc5bedeb8695bc2f96f098fbcc6f19ee8a759028dd5cc1287f606ba6767c +size 166350 diff --git a/global_step183/zero_pp_rank_3_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_3_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..03a2e008dab044fd5d44f4ec3c07395c111311f8 --- /dev/null +++ b/global_step183/zero_pp_rank_3_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8cb0b424582e9b89c7c60c75d96da9b80daee451c82c7fd2e6ab7f7298639a51 +size 166008 diff --git a/global_step183/zero_pp_rank_4_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_4_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..fe1cfb7fd972415cf43fcbfcd9366e9c3003707a --- /dev/null +++ b/global_step183/zero_pp_rank_4_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3440f24e40e1a37216f58ccf5ead1445c0d3ce1085b543d8281e74f872bbad3 +size 166008 diff --git a/global_step183/zero_pp_rank_5_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_5_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..f8a7df407379018e2394d8ed4b43de4fde28aab0 --- /dev/null +++ b/global_step183/zero_pp_rank_5_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:997c8b1f90770b25973323cec182aac71a8d590c0c9bdf0081682eed6fcca115 +size 166008 diff --git a/global_step183/zero_pp_rank_6_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_6_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..d65766918d5651b1edfb776d60d0ef38286f7908 --- /dev/null +++ b/global_step183/zero_pp_rank_6_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f6438d3c2077350d977e4316d00bfb84486fda1890706c8e4b4042b6c4dfc9ae +size 166008 diff --git a/global_step183/zero_pp_rank_7_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_7_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..f86e4bfd005ddc9334fd543ef00d53871bccb742 --- /dev/null +++ b/global_step183/zero_pp_rank_7_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2cca200c903622f558e40be418263997541b984855ae4556d530b2685dc4fd5f +size 166008 diff --git a/global_step183/zero_pp_rank_8_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_8_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..60a86a08e3871516bf5f88010c55bf485b3efd2e --- /dev/null +++ b/global_step183/zero_pp_rank_8_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aec3d92c0899c2cb29f6cbd09ca04403b62d69bccba7aea69f6fb66c249e1826 +size 166008 diff --git a/global_step183/zero_pp_rank_9_mp_rank_00_model_states.pt b/global_step183/zero_pp_rank_9_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..d22227c82bd069b8ce1c9019a8688870ff66d8aa --- /dev/null +++ b/global_step183/zero_pp_rank_9_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:518f79cd85185851a2ba6c87dac2a875c4cb1ab8e680ef2bdc96099172410f8c +size 166008 diff --git a/latest b/latest new file mode 100644 index 0000000000000000000000000000000000000000..38110fde80f49967b0573bb666889f91dbbb7d2f --- /dev/null +++ b/latest @@ -0,0 +1 @@ +global_step183 \ No newline at end of file diff --git a/train_results.json b/train_results.json new file mode 100644 index 0000000000000000000000000000000000000000..2586025d8c1a0779c7d44e4efee6daa055342286 --- /dev/null +++ b/train_results.json @@ -0,0 +1,8 @@ +{ + "total_flos": 83412022984704.0, + "train_loss": 0.6596692787564319, + "train_runtime": 4532.2478, + "train_samples": 93733, + "train_samples_per_second": 20.681, + "train_steps_per_second": 0.041 +} \ No newline at end of file diff --git a/trainer_state.json b/trainer_state.json new file mode 100644 index 0000000000000000000000000000000000000000..779a506a5808f936849eb901e983bd88158c2d22 --- /dev/null +++ b/trainer_state.json @@ -0,0 +1,1699 @@ +{ + "best_global_step": null, + "best_metric": null, + "best_model_checkpoint": null, + "epoch": 1.0, + "eval_steps": 500, + "global_step": 184, + "is_hyper_param_search": false, + "is_local_process_zero": true, + "is_world_process_zero": true, + "log_history": [ + { + "epoch": 0.005457025920873124, + "grad_norm": 10.22213420182019, + "learning_rate": 0.0, + "loss": 1.1573, + "mean_token_accuracy": 0.7164628058671951, + "num_tokens": 523676.0, + "step": 1 + }, + { + "epoch": 0.010914051841746248, + "grad_norm": 10.273668310890756, + "learning_rate": 1.0526315789473685e-06, + "loss": 1.1616, + "mean_token_accuracy": 0.7147120386362076, + "num_tokens": 1045809.0, + "step": 2 + }, + { + "epoch": 0.01637107776261937, + "grad_norm": 9.78774768102665, + "learning_rate": 2.105263157894737e-06, + "loss": 1.1567, + "mean_token_accuracy": 0.7149067521095276, + "num_tokens": 1570097.0, + "step": 3 + }, + { + "epoch": 0.021828103683492497, + "grad_norm": 8.94313097518904, + "learning_rate": 3.157894736842105e-06, + "loss": 1.1348, + "mean_token_accuracy": 0.7210559844970703, + "num_tokens": 2091845.0, + "step": 4 + }, + { + "epoch": 0.027285129604365622, + "grad_norm": 7.9721339575335834, + "learning_rate": 4.210526315789474e-06, + "loss": 1.1095, + "mean_token_accuracy": 0.7218570560216904, + "num_tokens": 2615847.0, + "step": 5 + }, + { + "epoch": 0.03274215552523874, + "grad_norm": 7.592127321700853, + "learning_rate": 5.263157894736842e-06, + "loss": 1.0809, + "mean_token_accuracy": 0.7219354957342148, + "num_tokens": 3138337.0, + "step": 6 + }, + { + "epoch": 0.03819918144611187, + "grad_norm": 5.919775437809893, + "learning_rate": 6.31578947368421e-06, + "loss": 1.0143, + "mean_token_accuracy": 0.7335961759090424, + "num_tokens": 3661593.0, + "step": 7 + }, + { + "epoch": 0.04365620736698499, + "grad_norm": 5.601592015786545, + "learning_rate": 7.368421052631579e-06, + "loss": 1.0196, + "mean_token_accuracy": 0.7284622341394424, + "num_tokens": 4185249.0, + "step": 8 + }, + { + "epoch": 0.04911323328785812, + "grad_norm": 4.982536388951164, + "learning_rate": 8.421052631578948e-06, + "loss": 0.9759, + "mean_token_accuracy": 0.7368100732564926, + "num_tokens": 4708215.0, + "step": 9 + }, + { + "epoch": 0.054570259208731244, + "grad_norm": 4.249175412307748, + "learning_rate": 9.473684210526315e-06, + "loss": 0.9085, + "mean_token_accuracy": 0.7506365925073624, + "num_tokens": 5230908.0, + "step": 10 + }, + { + "epoch": 0.06002728512960437, + "grad_norm": 2.66891797206269, + "learning_rate": 1.0526315789473684e-05, + "loss": 0.8817, + "mean_token_accuracy": 0.7550760209560394, + "num_tokens": 5753624.0, + "step": 11 + }, + { + "epoch": 0.06548431105047749, + "grad_norm": 1.901679552444703, + "learning_rate": 1.1578947368421053e-05, + "loss": 0.8638, + "mean_token_accuracy": 0.7565864026546478, + "num_tokens": 6276373.0, + "step": 12 + }, + { + "epoch": 0.07094133697135062, + "grad_norm": 1.9750884331299126, + "learning_rate": 1.263157894736842e-05, + "loss": 0.8356, + "mean_token_accuracy": 0.7596608996391296, + "num_tokens": 6800427.0, + "step": 13 + }, + { + "epoch": 0.07639836289222374, + "grad_norm": 2.0247436740822646, + "learning_rate": 1.3684210526315791e-05, + "loss": 0.8165, + "mean_token_accuracy": 0.7632935792207718, + "num_tokens": 7323928.0, + "step": 14 + }, + { + "epoch": 0.08185538881309687, + "grad_norm": 1.2580071009761389, + "learning_rate": 1.4736842105263159e-05, + "loss": 0.7843, + "mean_token_accuracy": 0.7702907770872116, + "num_tokens": 7847456.0, + "step": 15 + }, + { + "epoch": 0.08731241473396999, + "grad_norm": 1.1026372318409008, + "learning_rate": 1.578947368421053e-05, + "loss": 0.7812, + "mean_token_accuracy": 0.7720244079828262, + "num_tokens": 8370723.0, + "step": 16 + }, + { + "epoch": 0.0927694406548431, + "grad_norm": 0.9790512143637813, + "learning_rate": 1.6842105263157896e-05, + "loss": 0.7609, + "mean_token_accuracy": 0.7757307291030884, + "num_tokens": 8894432.0, + "step": 17 + }, + { + "epoch": 0.09822646657571624, + "grad_norm": 0.910145815422614, + "learning_rate": 1.7894736842105264e-05, + "loss": 0.7538, + "mean_token_accuracy": 0.7766335010528564, + "num_tokens": 9418720.0, + "step": 18 + }, + { + "epoch": 0.10368349249658936, + "grad_norm": 0.8433683095975495, + "learning_rate": 1.894736842105263e-05, + "loss": 0.7409, + "mean_token_accuracy": 0.7793708741664886, + "num_tokens": 9942800.0, + "step": 19 + }, + { + "epoch": 0.10914051841746249, + "grad_norm": 0.7670309303084534, + "learning_rate": 2e-05, + "loss": 0.7321, + "mean_token_accuracy": 0.7815710753202438, + "num_tokens": 10466884.0, + "step": 20 + }, + { + "epoch": 0.1145975443383356, + "grad_norm": 0.706537688385123, + "learning_rate": 1.999818745523526e-05, + "loss": 0.697, + "mean_token_accuracy": 0.7905226796865463, + "num_tokens": 10990544.0, + "step": 21 + }, + { + "epoch": 0.12005457025920874, + "grad_norm": 0.5527165898545957, + "learning_rate": 1.999275047800474e-05, + "loss": 0.7128, + "mean_token_accuracy": 0.7862321138381958, + "num_tokens": 11514374.0, + "step": 22 + }, + { + "epoch": 0.12551159618008187, + "grad_norm": 0.571109958623919, + "learning_rate": 1.9983691039261358e-05, + "loss": 0.7241, + "mean_token_accuracy": 0.7830730974674225, + "num_tokens": 12037544.0, + "step": 23 + }, + { + "epoch": 0.13096862210095497, + "grad_norm": 0.5370114426229838, + "learning_rate": 1.9971012423132776e-05, + "loss": 0.6982, + "mean_token_accuracy": 0.7900128364562988, + "num_tokens": 12561482.0, + "step": 24 + }, + { + "epoch": 0.1364256480218281, + "grad_norm": 0.5547056978188413, + "learning_rate": 1.9954719225730847e-05, + "loss": 0.7062, + "mean_token_accuracy": 0.7870719134807587, + "num_tokens": 13085632.0, + "step": 25 + }, + { + "epoch": 0.14188267394270124, + "grad_norm": 0.4977992915278872, + "learning_rate": 1.99348173534855e-05, + "loss": 0.7022, + "mean_token_accuracy": 0.7884464412927628, + "num_tokens": 13608889.0, + "step": 26 + }, + { + "epoch": 0.14733969986357434, + "grad_norm": 0.6284161854646022, + "learning_rate": 1.9911314021003614e-05, + "loss": 0.6929, + "mean_token_accuracy": 0.7912539541721344, + "num_tokens": 14132289.0, + "step": 27 + }, + { + "epoch": 0.15279672578444747, + "grad_norm": 0.5715615291456563, + "learning_rate": 1.9884217748453625e-05, + "loss": 0.6831, + "mean_token_accuracy": 0.7934748083353043, + "num_tokens": 14655501.0, + "step": 28 + }, + { + "epoch": 0.1582537517053206, + "grad_norm": 0.7940147931527408, + "learning_rate": 1.9853538358476933e-05, + "loss": 0.6985, + "mean_token_accuracy": 0.7891672253608704, + "num_tokens": 15178673.0, + "step": 29 + }, + { + "epoch": 0.16371077762619374, + "grad_norm": 0.42070194949907197, + "learning_rate": 1.9819286972627066e-05, + "loss": 0.6789, + "mean_token_accuracy": 0.7936267405748367, + "num_tokens": 15702134.0, + "step": 30 + }, + { + "epoch": 0.16916780354706684, + "grad_norm": 0.42689025398588065, + "learning_rate": 1.9781476007338058e-05, + "loss": 0.6661, + "mean_token_accuracy": 0.7965902388095856, + "num_tokens": 16225916.0, + "step": 31 + }, + { + "epoch": 0.17462482946793997, + "grad_norm": 0.4069071707695385, + "learning_rate": 1.9740119169423337e-05, + "loss": 0.6769, + "mean_token_accuracy": 0.7948543280363083, + "num_tokens": 16749401.0, + "step": 32 + }, + { + "epoch": 0.1800818553888131, + "grad_norm": 0.3899451941069501, + "learning_rate": 1.9695231451106914e-05, + "loss": 0.6849, + "mean_token_accuracy": 0.7926386445760727, + "num_tokens": 17273468.0, + "step": 33 + }, + { + "epoch": 0.1855388813096862, + "grad_norm": 0.38714269519831274, + "learning_rate": 1.964682912458856e-05, + "loss": 0.6689, + "mean_token_accuracy": 0.7962382435798645, + "num_tokens": 17795903.0, + "step": 34 + }, + { + "epoch": 0.19099590723055934, + "grad_norm": 0.3854479116669319, + "learning_rate": 1.9594929736144978e-05, + "loss": 0.6828, + "mean_token_accuracy": 0.792619377374649, + "num_tokens": 18317513.0, + "step": 35 + }, + { + "epoch": 0.19645293315143247, + "grad_norm": 0.36645962752185096, + "learning_rate": 1.9539552099769128e-05, + "loss": 0.6582, + "mean_token_accuracy": 0.7995733767747879, + "num_tokens": 18840309.0, + "step": 36 + }, + { + "epoch": 0.2019099590723056, + "grad_norm": 0.4016622694853883, + "learning_rate": 1.9480716290349998e-05, + "loss": 0.675, + "mean_token_accuracy": 0.7943409830331802, + "num_tokens": 19363271.0, + "step": 37 + }, + { + "epoch": 0.2073669849931787, + "grad_norm": 0.3598753892733462, + "learning_rate": 1.941844363639525e-05, + "loss": 0.6502, + "mean_token_accuracy": 0.8010804206132889, + "num_tokens": 19886861.0, + "step": 38 + }, + { + "epoch": 0.21282401091405184, + "grad_norm": 0.3767009308515463, + "learning_rate": 1.9352756712299467e-05, + "loss": 0.6622, + "mean_token_accuracy": 0.7980271279811859, + "num_tokens": 20410529.0, + "step": 39 + }, + { + "epoch": 0.21828103683492497, + "grad_norm": 0.36380332204946886, + "learning_rate": 1.9283679330160726e-05, + "loss": 0.6743, + "mean_token_accuracy": 0.7954283803701401, + "num_tokens": 20933244.0, + "step": 40 + }, + { + "epoch": 0.22373806275579808, + "grad_norm": 0.3372596025306025, + "learning_rate": 1.92112365311485e-05, + "loss": 0.6603, + "mean_token_accuracy": 0.7978256940841675, + "num_tokens": 21456172.0, + "step": 41 + }, + { + "epoch": 0.2291950886766712, + "grad_norm": 0.3431291193692692, + "learning_rate": 1.913545457642601e-05, + "loss": 0.6474, + "mean_token_accuracy": 0.8019589632749557, + "num_tokens": 21979434.0, + "step": 42 + }, + { + "epoch": 0.23465211459754434, + "grad_norm": 0.35120079952111755, + "learning_rate": 1.905636093763031e-05, + "loss": 0.6526, + "mean_token_accuracy": 0.8002910017967224, + "num_tokens": 22503439.0, + "step": 43 + }, + { + "epoch": 0.24010914051841747, + "grad_norm": 0.32515836450165586, + "learning_rate": 1.8973984286913584e-05, + "loss": 0.669, + "mean_token_accuracy": 0.7958613783121109, + "num_tokens": 23025997.0, + "step": 44 + }, + { + "epoch": 0.24556616643929058, + "grad_norm": 0.3423697459142544, + "learning_rate": 1.8888354486549238e-05, + "loss": 0.6508, + "mean_token_accuracy": 0.8007421046495438, + "num_tokens": 23548909.0, + "step": 45 + }, + { + "epoch": 0.25102319236016374, + "grad_norm": 0.34254341826311147, + "learning_rate": 1.8799502578106533e-05, + "loss": 0.6518, + "mean_token_accuracy": 0.7999228090047836, + "num_tokens": 24071561.0, + "step": 46 + }, + { + "epoch": 0.25648021828103684, + "grad_norm": 0.38246389596908054, + "learning_rate": 1.8707460771197773e-05, + "loss": 0.6643, + "mean_token_accuracy": 0.7966996878385544, + "num_tokens": 24594782.0, + "step": 47 + }, + { + "epoch": 0.26193724420190995, + "grad_norm": 0.36747149421521913, + "learning_rate": 1.861226243180201e-05, + "loss": 0.6424, + "mean_token_accuracy": 0.8020810931921005, + "num_tokens": 25117688.0, + "step": 48 + }, + { + "epoch": 0.2673942701227831, + "grad_norm": 0.3260159545079233, + "learning_rate": 1.8513942070169572e-05, + "loss": 0.6496, + "mean_token_accuracy": 0.8000700622797012, + "num_tokens": 25641600.0, + "step": 49 + }, + { + "epoch": 0.2728512960436562, + "grad_norm": 0.31918709162924563, + "learning_rate": 1.8412535328311813e-05, + "loss": 0.6469, + "mean_token_accuracy": 0.8012609481811523, + "num_tokens": 26165286.0, + "step": 50 + }, + { + "epoch": 0.2783083219645293, + "grad_norm": 0.3884773043904409, + "learning_rate": 1.8308078967080547e-05, + "loss": 0.6453, + "mean_token_accuracy": 0.802044153213501, + "num_tokens": 26688026.0, + "step": 51 + }, + { + "epoch": 0.2837653478854025, + "grad_norm": 0.3221640246981774, + "learning_rate": 1.8200610852841913e-05, + "loss": 0.6447, + "mean_token_accuracy": 0.8017777651548386, + "num_tokens": 27211433.0, + "step": 52 + }, + { + "epoch": 0.2892223738062756, + "grad_norm": 0.35950232000512794, + "learning_rate": 1.8090169943749477e-05, + "loss": 0.6393, + "mean_token_accuracy": 0.8036545515060425, + "num_tokens": 27735386.0, + "step": 53 + }, + { + "epoch": 0.2946793997271487, + "grad_norm": 0.3336227394165688, + "learning_rate": 1.7976796275621556e-05, + "loss": 0.6532, + "mean_token_accuracy": 0.7994280308485031, + "num_tokens": 28259425.0, + "step": 54 + }, + { + "epoch": 0.30013642564802184, + "grad_norm": 0.3159155930753706, + "learning_rate": 1.7860530947427878e-05, + "loss": 0.6367, + "mean_token_accuracy": 0.8045277893543243, + "num_tokens": 28783214.0, + "step": 55 + }, + { + "epoch": 0.30559345156889495, + "grad_norm": 0.33863627132547003, + "learning_rate": 1.7741416106390828e-05, + "loss": 0.6447, + "mean_token_accuracy": 0.8009888529777527, + "num_tokens": 29306632.0, + "step": 56 + }, + { + "epoch": 0.31105047748976805, + "grad_norm": 0.32414058492274905, + "learning_rate": 1.761949493270671e-05, + "loss": 0.6264, + "mean_token_accuracy": 0.8071126639842987, + "num_tokens": 29829363.0, + "step": 57 + }, + { + "epoch": 0.3165075034106412, + "grad_norm": 0.343900605873189, + "learning_rate": 1.7494811623892543e-05, + "loss": 0.6278, + "mean_token_accuracy": 0.806527853012085, + "num_tokens": 30352094.0, + "step": 58 + }, + { + "epoch": 0.3219645293315143, + "grad_norm": 0.36999037823632697, + "learning_rate": 1.736741137876405e-05, + "loss": 0.6284, + "mean_token_accuracy": 0.8057119399309158, + "num_tokens": 30875328.0, + "step": 59 + }, + { + "epoch": 0.3274215552523875, + "grad_norm": 0.3264916869231787, + "learning_rate": 1.72373403810507e-05, + "loss": 0.6335, + "mean_token_accuracy": 0.8045472204685211, + "num_tokens": 31399221.0, + "step": 60 + }, + { + "epoch": 0.3328785811732606, + "grad_norm": 0.4782591852354962, + "learning_rate": 1.710464578265369e-05, + "loss": 0.649, + "mean_token_accuracy": 0.8001821935176849, + "num_tokens": 31923238.0, + "step": 61 + }, + { + "epoch": 0.3383356070941337, + "grad_norm": 0.34512671491394814, + "learning_rate": 1.696937568655294e-05, + "loss": 0.6343, + "mean_token_accuracy": 0.8047307133674622, + "num_tokens": 32446843.0, + "step": 62 + }, + { + "epoch": 0.34379263301500684, + "grad_norm": 0.3795273366749496, + "learning_rate": 1.6831579129369347e-05, + "loss": 0.6232, + "mean_token_accuracy": 0.8063706010580063, + "num_tokens": 32970843.0, + "step": 63 + }, + { + "epoch": 0.34924965893587995, + "grad_norm": 0.37143998348188484, + "learning_rate": 1.6691306063588583e-05, + "loss": 0.6303, + "mean_token_accuracy": 0.8047699332237244, + "num_tokens": 33494468.0, + "step": 64 + }, + { + "epoch": 0.35470668485675305, + "grad_norm": 0.399010736777414, + "learning_rate": 1.6548607339452853e-05, + "loss": 0.6098, + "mean_token_accuracy": 0.8112862259149551, + "num_tokens": 34018179.0, + "step": 65 + }, + { + "epoch": 0.3601637107776262, + "grad_norm": 0.41886782643797665, + "learning_rate": 1.6403534686527223e-05, + "loss": 0.638, + "mean_token_accuracy": 0.8040533065795898, + "num_tokens": 34539809.0, + "step": 66 + }, + { + "epoch": 0.3656207366984993, + "grad_norm": 0.37173400377826826, + "learning_rate": 1.6256140694947217e-05, + "loss": 0.6198, + "mean_token_accuracy": 0.8078815788030624, + "num_tokens": 35063685.0, + "step": 67 + }, + { + "epoch": 0.3710777626193724, + "grad_norm": 0.37848673780253644, + "learning_rate": 1.6106478796354382e-05, + "loss": 0.6357, + "mean_token_accuracy": 0.8034750819206238, + "num_tokens": 35587259.0, + "step": 68 + }, + { + "epoch": 0.3765347885402456, + "grad_norm": 0.40573394210222646, + "learning_rate": 1.595460324452688e-05, + "loss": 0.6291, + "mean_token_accuracy": 0.8048963695764542, + "num_tokens": 36110522.0, + "step": 69 + }, + { + "epoch": 0.3819918144611187, + "grad_norm": 0.32059584094834276, + "learning_rate": 1.5800569095711983e-05, + "loss": 0.6284, + "mean_token_accuracy": 0.8049021810293198, + "num_tokens": 36634670.0, + "step": 70 + }, + { + "epoch": 0.3874488403819918, + "grad_norm": 0.38182900761291677, + "learning_rate": 1.5644432188667695e-05, + "loss": 0.6406, + "mean_token_accuracy": 0.8030418157577515, + "num_tokens": 37157066.0, + "step": 71 + }, + { + "epoch": 0.39290586630286495, + "grad_norm": 0.3792868667331441, + "learning_rate": 1.5486249124420702e-05, + "loss": 0.6216, + "mean_token_accuracy": 0.8081106394529343, + "num_tokens": 37680580.0, + "step": 72 + }, + { + "epoch": 0.39836289222373805, + "grad_norm": 0.3505388571032621, + "learning_rate": 1.5326077245747998e-05, + "loss": 0.6417, + "mean_token_accuracy": 0.8015989065170288, + "num_tokens": 38204191.0, + "step": 73 + }, + { + "epoch": 0.4038199181446112, + "grad_norm": 0.34629977214808816, + "learning_rate": 1.5163974616389621e-05, + "loss": 0.6119, + "mean_token_accuracy": 0.8100632429122925, + "num_tokens": 38727916.0, + "step": 74 + }, + { + "epoch": 0.4092769440654843, + "grad_norm": 0.3432985248757669, + "learning_rate": 1.5000000000000002e-05, + "loss": 0.6189, + "mean_token_accuracy": 0.8073546588420868, + "num_tokens": 39251672.0, + "step": 75 + }, + { + "epoch": 0.4147339699863574, + "grad_norm": 0.34896847344460213, + "learning_rate": 1.4834212838845639e-05, + "loss": 0.6251, + "mean_token_accuracy": 0.8063764572143555, + "num_tokens": 39775687.0, + "step": 76 + }, + { + "epoch": 0.4201909959072306, + "grad_norm": 0.3625496780454411, + "learning_rate": 1.4666673232256738e-05, + "loss": 0.6328, + "mean_token_accuracy": 0.8043248653411865, + "num_tokens": 40298660.0, + "step": 77 + }, + { + "epoch": 0.4256480218281037, + "grad_norm": 0.3532777855472362, + "learning_rate": 1.449744191484066e-05, + "loss": 0.6409, + "mean_token_accuracy": 0.8027763664722443, + "num_tokens": 40822948.0, + "step": 78 + }, + { + "epoch": 0.4311050477489768, + "grad_norm": 0.35365330484423485, + "learning_rate": 1.4326580234465084e-05, + "loss": 0.626, + "mean_token_accuracy": 0.8057489842176437, + "num_tokens": 41345922.0, + "step": 79 + }, + { + "epoch": 0.43656207366984995, + "grad_norm": 0.3349393977670838, + "learning_rate": 1.4154150130018867e-05, + "loss": 0.6175, + "mean_token_accuracy": 0.808809220790863, + "num_tokens": 41869207.0, + "step": 80 + }, + { + "epoch": 0.44201909959072305, + "grad_norm": 0.35933628997277556, + "learning_rate": 1.3980214108958626e-05, + "loss": 0.6277, + "mean_token_accuracy": 0.8056986331939697, + "num_tokens": 42393021.0, + "step": 81 + }, + { + "epoch": 0.44747612551159616, + "grad_norm": 0.33096454453402563, + "learning_rate": 1.380483522464923e-05, + "loss": 0.6282, + "mean_token_accuracy": 0.8058310747146606, + "num_tokens": 42916182.0, + "step": 82 + }, + { + "epoch": 0.4529331514324693, + "grad_norm": 0.30478624197164905, + "learning_rate": 1.362807705350641e-05, + "loss": 0.622, + "mean_token_accuracy": 0.8063958883285522, + "num_tokens": 43438714.0, + "step": 83 + }, + { + "epoch": 0.4583901773533424, + "grad_norm": 0.3501335389226804, + "learning_rate": 1.3450003671949707e-05, + "loss": 0.6316, + "mean_token_accuracy": 0.8037978112697601, + "num_tokens": 43963002.0, + "step": 84 + }, + { + "epoch": 0.4638472032742155, + "grad_norm": 0.3050292683152368, + "learning_rate": 1.3270679633174219e-05, + "loss": 0.6313, + "mean_token_accuracy": 0.8047272562980652, + "num_tokens": 44486140.0, + "step": 85 + }, + { + "epoch": 0.4693042291950887, + "grad_norm": 0.37055749603164684, + "learning_rate": 1.3090169943749475e-05, + "loss": 0.6207, + "mean_token_accuracy": 0.8078664541244507, + "num_tokens": 45009734.0, + "step": 86 + }, + { + "epoch": 0.4747612551159618, + "grad_norm": 0.2984594081324396, + "learning_rate": 1.2908540040053992e-05, + "loss": 0.6432, + "mean_token_accuracy": 0.8011371046304703, + "num_tokens": 45532685.0, + "step": 87 + }, + { + "epoch": 0.48021828103683495, + "grad_norm": 0.3800002310817578, + "learning_rate": 1.2725855764553981e-05, + "loss": 0.618, + "mean_token_accuracy": 0.8085715621709824, + "num_tokens": 46056685.0, + "step": 88 + }, + { + "epoch": 0.48567530695770805, + "grad_norm": 0.330018519060922, + "learning_rate": 1.2542183341934873e-05, + "loss": 0.6319, + "mean_token_accuracy": 0.8046033978462219, + "num_tokens": 46580056.0, + "step": 89 + }, + { + "epoch": 0.49113233287858116, + "grad_norm": 0.3375380842374477, + "learning_rate": 1.2357589355094275e-05, + "loss": 0.6208, + "mean_token_accuracy": 0.8076395243406296, + "num_tokens": 47103434.0, + "step": 90 + }, + { + "epoch": 0.4965893587994543, + "grad_norm": 0.3164441239625524, + "learning_rate": 1.217214072100508e-05, + "loss": 0.6184, + "mean_token_accuracy": 0.8088521063327789, + "num_tokens": 47625849.0, + "step": 91 + }, + { + "epoch": 0.5020463847203275, + "grad_norm": 0.30994257381808427, + "learning_rate": 1.1985904666457455e-05, + "loss": 0.6137, + "mean_token_accuracy": 0.8092475831508636, + "num_tokens": 48149477.0, + "step": 92 + }, + { + "epoch": 0.5075034106412005, + "grad_norm": 0.32961807061620957, + "learning_rate": 1.179894870368854e-05, + "loss": 0.6144, + "mean_token_accuracy": 0.8092338591814041, + "num_tokens": 48672968.0, + "step": 93 + }, + { + "epoch": 0.5129604365620737, + "grad_norm": 0.2954751932341951, + "learning_rate": 1.1611340605908643e-05, + "loss": 0.6293, + "mean_token_accuracy": 0.804235503077507, + "num_tokens": 49196609.0, + "step": 94 + }, + { + "epoch": 0.5184174624829468, + "grad_norm": 0.3073584261153331, + "learning_rate": 1.1423148382732854e-05, + "loss": 0.6073, + "mean_token_accuracy": 0.8108531385660172, + "num_tokens": 49720351.0, + "step": 95 + }, + { + "epoch": 0.5238744884038199, + "grad_norm": 0.31592897133822806, + "learning_rate": 1.1234440255526948e-05, + "loss": 0.6171, + "mean_token_accuracy": 0.8081915378570557, + "num_tokens": 50243733.0, + "step": 96 + }, + { + "epoch": 0.529331514324693, + "grad_norm": 0.34742265287757973, + "learning_rate": 1.1045284632676535e-05, + "loss": 0.6318, + "mean_token_accuracy": 0.8044733256101608, + "num_tokens": 50767754.0, + "step": 97 + }, + { + "epoch": 0.5347885402455662, + "grad_norm": 0.34110294551441445, + "learning_rate": 1.08557500847884e-05, + "loss": 0.6188, + "mean_token_accuracy": 0.8076249808073044, + "num_tokens": 51291988.0, + "step": 98 + }, + { + "epoch": 0.5402455661664393, + "grad_norm": 0.32189183372762564, + "learning_rate": 1.066590531983304e-05, + "loss": 0.6152, + "mean_token_accuracy": 0.8090378940105438, + "num_tokens": 51815310.0, + "step": 99 + }, + { + "epoch": 0.5457025920873124, + "grad_norm": 0.29768781279014606, + "learning_rate": 1.0475819158237426e-05, + "loss": 0.6247, + "mean_token_accuracy": 0.806323915719986, + "num_tokens": 52338013.0, + "step": 100 + }, + { + "epoch": 0.5511596180081856, + "grad_norm": 0.33743263078601365, + "learning_rate": 1.0285560507936962e-05, + "loss": 0.6294, + "mean_token_accuracy": 0.8047656267881393, + "num_tokens": 52860936.0, + "step": 101 + }, + { + "epoch": 0.5566166439290586, + "grad_norm": 0.30513251377145567, + "learning_rate": 1.0095198339395769e-05, + "loss": 0.6136, + "mean_token_accuracy": 0.8092543631792068, + "num_tokens": 53383863.0, + "step": 102 + }, + { + "epoch": 0.5620736698499318, + "grad_norm": 0.2822146986279436, + "learning_rate": 9.904801660604234e-06, + "loss": 0.6096, + "mean_token_accuracy": 0.8101857900619507, + "num_tokens": 53906296.0, + "step": 103 + }, + { + "epoch": 0.567530695770805, + "grad_norm": 0.2846732632438344, + "learning_rate": 9.71443949206304e-06, + "loss": 0.6145, + "mean_token_accuracy": 0.8090496212244034, + "num_tokens": 54429769.0, + "step": 104 + }, + { + "epoch": 0.572987721691678, + "grad_norm": 0.29261503618482076, + "learning_rate": 9.524180841762577e-06, + "loss": 0.6077, + "mean_token_accuracy": 0.8109816312789917, + "num_tokens": 54953622.0, + "step": 105 + }, + { + "epoch": 0.5784447476125512, + "grad_norm": 0.29743888731648277, + "learning_rate": 9.334094680166962e-06, + "loss": 0.6254, + "mean_token_accuracy": 0.8057630807161331, + "num_tokens": 55477635.0, + "step": 106 + }, + { + "epoch": 0.5839017735334243, + "grad_norm": 0.313788538568538, + "learning_rate": 9.144249915211605e-06, + "loss": 0.5959, + "mean_token_accuracy": 0.8136271983385086, + "num_tokens": 56000074.0, + "step": 107 + }, + { + "epoch": 0.5893587994542974, + "grad_norm": 0.28533146520527214, + "learning_rate": 8.954715367323468e-06, + "loss": 0.6167, + "mean_token_accuracy": 0.8088293522596359, + "num_tokens": 56522578.0, + "step": 108 + }, + { + "epoch": 0.5948158253751705, + "grad_norm": 0.30428085575192637, + "learning_rate": 8.765559744473054e-06, + "loss": 0.6113, + "mean_token_accuracy": 0.8102127313613892, + "num_tokens": 57046143.0, + "step": 109 + }, + { + "epoch": 0.6002728512960437, + "grad_norm": 0.2816148674353229, + "learning_rate": 8.576851617267151e-06, + "loss": 0.6024, + "mean_token_accuracy": 0.8121312856674194, + "num_tokens": 57569414.0, + "step": 110 + }, + { + "epoch": 0.6057298772169167, + "grad_norm": 0.2745800257475884, + "learning_rate": 8.388659394091362e-06, + "loss": 0.604, + "mean_token_accuracy": 0.8116618692874908, + "num_tokens": 58093133.0, + "step": 111 + }, + { + "epoch": 0.6111869031377899, + "grad_norm": 0.2852956217506078, + "learning_rate": 8.201051296311462e-06, + "loss": 0.6121, + "mean_token_accuracy": 0.8092280626296997, + "num_tokens": 58616445.0, + "step": 112 + }, + { + "epoch": 0.616643929058663, + "grad_norm": 0.3028924823430585, + "learning_rate": 8.014095333542548e-06, + "loss": 0.6277, + "mean_token_accuracy": 0.8050173074007034, + "num_tokens": 59140029.0, + "step": 113 + }, + { + "epoch": 0.6221009549795361, + "grad_norm": 0.2878705074568443, + "learning_rate": 7.827859278994924e-06, + "loss": 0.6183, + "mean_token_accuracy": 0.8077353686094284, + "num_tokens": 59663111.0, + "step": 114 + }, + { + "epoch": 0.6275579809004093, + "grad_norm": 0.2915197799998971, + "learning_rate": 7.642410644905726e-06, + "loss": 0.6055, + "mean_token_accuracy": 0.8112288415431976, + "num_tokens": 60187141.0, + "step": 115 + }, + { + "epoch": 0.6330150068212824, + "grad_norm": 0.28904450320434766, + "learning_rate": 7.4578166580651335e-06, + "loss": 0.5984, + "mean_token_accuracy": 0.8130005896091461, + "num_tokens": 60711004.0, + "step": 116 + }, + { + "epoch": 0.6384720327421555, + "grad_norm": 0.28028896566316097, + "learning_rate": 7.274144235446024e-06, + "loss": 0.6186, + "mean_token_accuracy": 0.807890847325325, + "num_tokens": 61233923.0, + "step": 117 + }, + { + "epoch": 0.6439290586630286, + "grad_norm": 0.2913180494222617, + "learning_rate": 7.0914599599460095e-06, + "loss": 0.6136, + "mean_token_accuracy": 0.8088892549276352, + "num_tokens": 61758088.0, + "step": 118 + }, + { + "epoch": 0.6493860845839018, + "grad_norm": 0.30378873402714074, + "learning_rate": 6.909830056250527e-06, + "loss": 0.6079, + "mean_token_accuracy": 0.810408428311348, + "num_tokens": 62281964.0, + "step": 119 + }, + { + "epoch": 0.654843110504775, + "grad_norm": 0.2805017521562141, + "learning_rate": 6.729320366825785e-06, + "loss": 0.6015, + "mean_token_accuracy": 0.8118923008441925, + "num_tokens": 62805421.0, + "step": 120 + }, + { + "epoch": 0.660300136425648, + "grad_norm": 0.2900380570316718, + "learning_rate": 6.549996328050296e-06, + "loss": 0.6162, + "mean_token_accuracy": 0.8080224990844727, + "num_tokens": 63327512.0, + "step": 121 + }, + { + "epoch": 0.6657571623465212, + "grad_norm": 0.28161953335473705, + "learning_rate": 6.3719229464935915e-06, + "loss": 0.611, + "mean_token_accuracy": 0.8101419806480408, + "num_tokens": 63851052.0, + "step": 122 + }, + { + "epoch": 0.6712141882673943, + "grad_norm": 0.27911943019979396, + "learning_rate": 6.19516477535077e-06, + "loss": 0.6144, + "mean_token_accuracy": 0.8090348690748215, + "num_tokens": 64373976.0, + "step": 123 + }, + { + "epoch": 0.6766712141882674, + "grad_norm": 0.28869189523406197, + "learning_rate": 6.019785891041381e-06, + "loss": 0.6027, + "mean_token_accuracy": 0.8112694770097733, + "num_tokens": 64897673.0, + "step": 124 + }, + { + "epoch": 0.6821282401091405, + "grad_norm": 0.3116639999006473, + "learning_rate": 5.845849869981137e-06, + "loss": 0.6071, + "mean_token_accuracy": 0.8102435320615768, + "num_tokens": 65420396.0, + "step": 125 + }, + { + "epoch": 0.6875852660300137, + "grad_norm": 0.2794198694229315, + "learning_rate": 5.673419765534915e-06, + "loss": 0.6139, + "mean_token_accuracy": 0.8096896409988403, + "num_tokens": 65944684.0, + "step": 126 + }, + { + "epoch": 0.6930422919508867, + "grad_norm": 0.29207321018083127, + "learning_rate": 5.502558085159344e-06, + "loss": 0.6134, + "mean_token_accuracy": 0.8090117424726486, + "num_tokens": 66467447.0, + "step": 127 + }, + { + "epoch": 0.6984993178717599, + "grad_norm": 0.29215460275378646, + "learning_rate": 5.333326767743263e-06, + "loss": 0.6035, + "mean_token_accuracy": 0.8116814643144608, + "num_tokens": 66990954.0, + "step": 128 + }, + { + "epoch": 0.703956343792633, + "grad_norm": 0.30182577234260344, + "learning_rate": 5.165787161154361e-06, + "loss": 0.6198, + "mean_token_accuracy": 0.8080471009016037, + "num_tokens": 67513914.0, + "step": 129 + }, + { + "epoch": 0.7094133697135061, + "grad_norm": 0.28575689674009214, + "learning_rate": 5.000000000000003e-06, + "loss": 0.6156, + "mean_token_accuracy": 0.8082942962646484, + "num_tokens": 68036824.0, + "step": 130 + }, + { + "epoch": 0.7148703956343793, + "grad_norm": 0.28922672322466253, + "learning_rate": 4.836025383610382e-06, + "loss": 0.6003, + "mean_token_accuracy": 0.8127593100070953, + "num_tokens": 68559651.0, + "step": 131 + }, + { + "epoch": 0.7203274215552524, + "grad_norm": 0.298101675623923, + "learning_rate": 4.673922754252001e-06, + "loss": 0.6168, + "mean_token_accuracy": 0.8083733916282654, + "num_tokens": 69082066.0, + "step": 132 + }, + { + "epoch": 0.7257844474761255, + "grad_norm": 0.2744186495019478, + "learning_rate": 4.513750875579303e-06, + "loss": 0.6054, + "mean_token_accuracy": 0.8108874261379242, + "num_tokens": 69604652.0, + "step": 133 + }, + { + "epoch": 0.7312414733969986, + "grad_norm": 0.2830493473939992, + "learning_rate": 4.355567811332311e-06, + "loss": 0.6124, + "mean_token_accuracy": 0.8090694695711136, + "num_tokens": 70128721.0, + "step": 134 + }, + { + "epoch": 0.7366984993178718, + "grad_norm": 0.26452385051400257, + "learning_rate": 4.19943090428802e-06, + "loss": 0.609, + "mean_token_accuracy": 0.8103781342506409, + "num_tokens": 70651715.0, + "step": 135 + }, + { + "epoch": 0.7421555252387448, + "grad_norm": 0.302575355186892, + "learning_rate": 4.045396755473121e-06, + "loss": 0.6155, + "mean_token_accuracy": 0.8087055832147598, + "num_tokens": 71175451.0, + "step": 136 + }, + { + "epoch": 0.747612551159618, + "grad_norm": 0.29021714961394307, + "learning_rate": 3.893521203645618e-06, + "loss": 0.6107, + "mean_token_accuracy": 0.8100896179676056, + "num_tokens": 71699274.0, + "step": 137 + }, + { + "epoch": 0.7530695770804912, + "grad_norm": 0.26530999183379256, + "learning_rate": 3.743859305052785e-06, + "loss": 0.6012, + "mean_token_accuracy": 0.8125804513692856, + "num_tokens": 72221747.0, + "step": 138 + }, + { + "epoch": 0.7585266030013642, + "grad_norm": 0.2664134356854169, + "learning_rate": 3.596465313472778e-06, + "loss": 0.6049, + "mean_token_accuracy": 0.8113019466400146, + "num_tokens": 72745367.0, + "step": 139 + }, + { + "epoch": 0.7639836289222374, + "grad_norm": 0.2726636629633549, + "learning_rate": 3.4513926605471504e-06, + "loss": 0.6029, + "mean_token_accuracy": 0.8118140697479248, + "num_tokens": 73269309.0, + "step": 140 + }, + { + "epoch": 0.7694406548431105, + "grad_norm": 0.2823783285707402, + "learning_rate": 3.308693936411421e-06, + "loss": 0.5984, + "mean_token_accuracy": 0.8139031380414963, + "num_tokens": 73793330.0, + "step": 141 + }, + { + "epoch": 0.7748976807639836, + "grad_norm": 0.2842494431708805, + "learning_rate": 3.1684208706306572e-06, + "loss": 0.6038, + "mean_token_accuracy": 0.811387911438942, + "num_tokens": 74315676.0, + "step": 142 + }, + { + "epoch": 0.7803547066848567, + "grad_norm": 0.27140851078870026, + "learning_rate": 3.0306243134470668e-06, + "loss": 0.6013, + "mean_token_accuracy": 0.8124092221260071, + "num_tokens": 74839637.0, + "step": 143 + }, + { + "epoch": 0.7858117326057299, + "grad_norm": 0.2798670919794706, + "learning_rate": 2.8953542173463133e-06, + "loss": 0.6106, + "mean_token_accuracy": 0.8100556433200836, + "num_tokens": 75363497.0, + "step": 144 + }, + { + "epoch": 0.791268758526603, + "grad_norm": 0.28033902054535437, + "learning_rate": 2.7626596189492983e-06, + "loss": 0.6027, + "mean_token_accuracy": 0.8115111291408539, + "num_tokens": 75887489.0, + "step": 145 + }, + { + "epoch": 0.7967257844474761, + "grad_norm": 0.27490448958024705, + "learning_rate": 2.6325886212359496e-06, + "loss": 0.6182, + "mean_token_accuracy": 0.8073505163192749, + "num_tokens": 76409584.0, + "step": 146 + }, + { + "epoch": 0.8021828103683493, + "grad_norm": 0.24675144845595406, + "learning_rate": 2.5051883761074613e-06, + "loss": 0.6028, + "mean_token_accuracy": 0.8121288418769836, + "num_tokens": 76931524.0, + "step": 147 + }, + { + "epoch": 0.8076398362892224, + "grad_norm": 0.25519179346777965, + "learning_rate": 2.380505067293293e-06, + "loss": 0.6196, + "mean_token_accuracy": 0.8076845556497574, + "num_tokens": 77454057.0, + "step": 148 + }, + { + "epoch": 0.8130968622100955, + "grad_norm": 0.26185996250871496, + "learning_rate": 2.2585838936091753e-06, + "loss": 0.6017, + "mean_token_accuracy": 0.8125504702329636, + "num_tokens": 77977862.0, + "step": 149 + }, + { + "epoch": 0.8185538881309686, + "grad_norm": 0.26113753260747125, + "learning_rate": 2.1394690525721275e-06, + "loss": 0.6004, + "mean_token_accuracy": 0.8124841898679733, + "num_tokens": 78501396.0, + "step": 150 + }, + { + "epoch": 0.8240109140518418, + "grad_norm": 0.26467961910583754, + "learning_rate": 2.0232037243784475e-06, + "loss": 0.6119, + "mean_token_accuracy": 0.8094252794981003, + "num_tokens": 79024199.0, + "step": 151 + }, + { + "epoch": 0.8294679399727148, + "grad_norm": 0.2644047665450295, + "learning_rate": 1.9098300562505266e-06, + "loss": 0.6004, + "mean_token_accuracy": 0.8127636909484863, + "num_tokens": 79548077.0, + "step": 152 + }, + { + "epoch": 0.834924965893588, + "grad_norm": 0.2516602132329715, + "learning_rate": 1.7993891471580894e-06, + "loss": 0.6205, + "mean_token_accuracy": 0.8068733364343643, + "num_tokens": 80070613.0, + "step": 153 + }, + { + "epoch": 0.8403819918144612, + "grad_norm": 0.25189418048747564, + "learning_rate": 1.6919210329194535e-06, + "loss": 0.5925, + "mean_token_accuracy": 0.81495700776577, + "num_tokens": 80593271.0, + "step": 154 + }, + { + "epoch": 0.8458390177353342, + "grad_norm": 0.2541947348763928, + "learning_rate": 1.587464671688187e-06, + "loss": 0.6052, + "mean_token_accuracy": 0.8111275136470795, + "num_tokens": 81117027.0, + "step": 155 + }, + { + "epoch": 0.8512960436562074, + "grad_norm": 0.2485428118091583, + "learning_rate": 1.4860579298304311e-06, + "loss": 0.6027, + "mean_token_accuracy": 0.8118527084589005, + "num_tokens": 81639970.0, + "step": 156 + }, + { + "epoch": 0.8567530695770805, + "grad_norm": 0.2558642894347357, + "learning_rate": 1.3877375681979944e-06, + "loss": 0.6179, + "mean_token_accuracy": 0.8081785440444946, + "num_tokens": 82164220.0, + "step": 157 + }, + { + "epoch": 0.8622100954979536, + "grad_norm": 0.2616239161200991, + "learning_rate": 1.2925392288022299e-06, + "loss": 0.6006, + "mean_token_accuracy": 0.8130854815244675, + "num_tokens": 82687126.0, + "step": 158 + }, + { + "epoch": 0.8676671214188267, + "grad_norm": 0.2503849676832415, + "learning_rate": 1.2004974218934695e-06, + "loss": 0.6121, + "mean_token_accuracy": 0.8090514242649078, + "num_tokens": 83211082.0, + "step": 159 + }, + { + "epoch": 0.8731241473396999, + "grad_norm": 0.25569881744843687, + "learning_rate": 1.1116455134507665e-06, + "loss": 0.5978, + "mean_token_accuracy": 0.8128385543823242, + "num_tokens": 83734481.0, + "step": 160 + }, + { + "epoch": 0.878581173260573, + "grad_norm": 0.258101439807178, + "learning_rate": 1.0260157130864178e-06, + "loss": 0.5997, + "mean_token_accuracy": 0.8123024553060532, + "num_tokens": 84258601.0, + "step": 161 + }, + { + "epoch": 0.8840381991814461, + "grad_norm": 0.25927293039307325, + "learning_rate": 9.436390623696911e-07, + "loss": 0.6124, + "mean_token_accuracy": 0.8082059770822525, + "num_tokens": 84781078.0, + "step": 162 + }, + { + "epoch": 0.8894952251023193, + "grad_norm": 0.2526035772284644, + "learning_rate": 8.645454235739903e-07, + "loss": 0.5948, + "mean_token_accuracy": 0.8139047026634216, + "num_tokens": 85303323.0, + "step": 163 + }, + { + "epoch": 0.8949522510231923, + "grad_norm": 0.24869120811837703, + "learning_rate": 7.887634688515e-07, + "loss": 0.6007, + "mean_token_accuracy": 0.8121525943279266, + "num_tokens": 85826587.0, + "step": 164 + }, + { + "epoch": 0.9004092769440655, + "grad_norm": 0.2585783712565244, + "learning_rate": 7.163206698392744e-07, + "loss": 0.6067, + "mean_token_accuracy": 0.8101864755153656, + "num_tokens": 86348762.0, + "step": 165 + }, + { + "epoch": 0.9058663028649386, + "grad_norm": 0.2518620468900349, + "learning_rate": 6.472432877005341e-07, + "loss": 0.5887, + "mean_token_accuracy": 0.8158304989337921, + "num_tokens": 86871300.0, + "step": 166 + }, + { + "epoch": 0.9113233287858117, + "grad_norm": 0.244004632190641, + "learning_rate": 5.815563636047539e-07, + "loss": 0.6072, + "mean_token_accuracy": 0.810369223356247, + "num_tokens": 87394281.0, + "step": 167 + }, + { + "epoch": 0.9167803547066848, + "grad_norm": 0.24445608373025496, + "learning_rate": 5.192837096500058e-07, + "loss": 0.5996, + "mean_token_accuracy": 0.8120895475149155, + "num_tokens": 87918100.0, + "step": 168 + }, + { + "epoch": 0.922237380627558, + "grad_norm": 0.2513116743689248, + "learning_rate": 4.6044790023087373e-07, + "loss": 0.6019, + "mean_token_accuracy": 0.812640592455864, + "num_tokens": 88441158.0, + "step": 169 + }, + { + "epoch": 0.927694406548431, + "grad_norm": 0.24582299262926002, + "learning_rate": 4.0507026385502747e-07, + "loss": 0.6141, + "mean_token_accuracy": 0.8093185424804688, + "num_tokens": 88964881.0, + "step": 170 + }, + { + "epoch": 0.9331514324693042, + "grad_norm": 0.24125628809062805, + "learning_rate": 3.531708754114438e-07, + "loss": 0.5934, + "mean_token_accuracy": 0.8144457191228867, + "num_tokens": 89487871.0, + "step": 171 + }, + { + "epoch": 0.9386084583901774, + "grad_norm": 0.24933232718734447, + "learning_rate": 3.0476854889308737e-07, + "loss": 0.6121, + "mean_token_accuracy": 0.8091763854026794, + "num_tokens": 90011304.0, + "step": 172 + }, + { + "epoch": 0.9440654843110505, + "grad_norm": 0.24998307613649684, + "learning_rate": 2.5988083057666534e-07, + "loss": 0.6128, + "mean_token_accuracy": 0.8095022439956665, + "num_tokens": 90535512.0, + "step": 173 + }, + { + "epoch": 0.9495225102319236, + "grad_norm": 0.2603433846994922, + "learning_rate": 2.1852399266194312e-07, + "loss": 0.6096, + "mean_token_accuracy": 0.8097108900547028, + "num_tokens": 91057921.0, + "step": 174 + }, + { + "epoch": 0.9549795361527967, + "grad_norm": 0.2472084519543789, + "learning_rate": 1.8071302737293294e-07, + "loss": 0.6117, + "mean_token_accuracy": 0.8093229830265045, + "num_tokens": 91581472.0, + "step": 175 + }, + { + "epoch": 0.9604365620736699, + "grad_norm": 0.24695647352252775, + "learning_rate": 1.464616415230702e-07, + "loss": 0.6055, + "mean_token_accuracy": 0.810840904712677, + "num_tokens": 92105494.0, + "step": 176 + }, + { + "epoch": 0.965893587994543, + "grad_norm": 0.24599551241247092, + "learning_rate": 1.1578225154637579e-07, + "loss": 0.6038, + "mean_token_accuracy": 0.8120936304330826, + "num_tokens": 92629136.0, + "step": 177 + }, + { + "epoch": 0.9713506139154161, + "grad_norm": 0.25292592350349663, + "learning_rate": 8.868597899638897e-08, + "loss": 0.5969, + "mean_token_accuracy": 0.813206359744072, + "num_tokens": 93153344.0, + "step": 178 + }, + { + "epoch": 0.9768076398362893, + "grad_norm": 0.24220493608697735, + "learning_rate": 6.51826465144978e-08, + "loss": 0.6064, + "mean_token_accuracy": 0.8107435554265976, + "num_tokens": 93677179.0, + "step": 179 + }, + { + "epoch": 0.9822646657571623, + "grad_norm": 0.251125383266427, + "learning_rate": 4.528077426915412e-08, + "loss": 0.6135, + "mean_token_accuracy": 0.8094018846750259, + "num_tokens": 94199742.0, + "step": 180 + }, + { + "epoch": 0.9877216916780355, + "grad_norm": 0.24330075967604267, + "learning_rate": 2.898757686722542e-08, + "loss": 0.6076, + "mean_token_accuracy": 0.8115710318088531, + "num_tokens": 94723291.0, + "step": 181 + }, + { + "epoch": 0.9931787175989086, + "grad_norm": 0.25502156218936445, + "learning_rate": 1.630896073864352e-08, + "loss": 0.6099, + "mean_token_accuracy": 0.8100082129240036, + "num_tokens": 95247067.0, + "step": 182 + }, + { + "epoch": 0.9986357435197817, + "grad_norm": 0.24338031531558027, + "learning_rate": 7.2495219952639636e-09, + "loss": 0.5994, + "mean_token_accuracy": 0.8127514272928238, + "num_tokens": 95771253.0, + "step": 183 + }, + { + "epoch": 1.0, + "grad_norm": 0.24338031531558027, + "learning_rate": 1.8125447647421302e-09, + "loss": 0.5996, + "mean_token_accuracy": 0.8128812313079834, + "num_tokens": 95902052.0, + "step": 184 + }, + { + "epoch": 1.0, + "step": 184, + "total_flos": 83412022984704.0, + "train_loss": 0.6596692787564319, + "train_runtime": 4532.2478, + "train_samples_per_second": 20.681, + "train_steps_per_second": 0.041 + } + ], + "logging_steps": 1, + "max_steps": 184, + "num_input_tokens_seen": 0, + "num_train_epochs": 1, + "save_steps": 500, + "stateful_callbacks": { + "TrainerControl": { + "args": { + "should_epoch_stop": false, + "should_evaluate": false, + "should_log": false, + "should_save": true, + "should_training_stop": true + }, + "attributes": {} + } + }, + "total_flos": 83412022984704.0, + "train_batch_size": 4, + "trial_name": null, + "trial_params": null +} diff --git a/zero_to_fp32.py b/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..0e759146cadd92ddfefab3680146c2bd6a2b5c04 --- /dev/null +++ b/zero_to_fp32.py @@ -0,0 +1,760 @@ +#!/usr/bin/env python + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: +# python zero_to_fp32.py . output_dir/ +# or +# python zero_to_fp32.py . output_dir/ --safe_serialization + +import argparse +import torch +import glob +import math +import os +import re +import gc +import json +import numpy as np +from tqdm import tqdm +from collections import OrderedDict +from dataclasses import dataclass + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. +from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + + +@dataclass +class zero_model_state: + buffers: dict() + param_shapes: dict() + shared_params: list + ds_version: int + frozen_param_shapes: dict() + frozen_param_fragments: dict() + + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage <= 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_checkpoint_files(checkpoint_dir, glob_pattern): + # XXX: need to test that this simple glob rule works for multi-node setup too + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) + + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") + + return ckpt_files + + +def get_optim_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") + + +def get_model_state_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") + + +def parse_model_states(files): + zero_model_states = [] + for file in files: + state_dict = torch.load(file, map_location=device, weights_only=False) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} + param_shapes = state_dict[PARAM_SHAPES] + + # collect parameters that are included in param_shapes + param_names = [] + for s in param_shapes: + for name in s.keys(): + param_names.append(name) + + # update with frozen parameters + frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) + if frozen_param_shapes is not None: + if debug: + print(f"Found frozen_param_shapes: {frozen_param_shapes}") + param_names += list(frozen_param_shapes.keys()) + + # handle shared params + shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] + + ds_version = state_dict.get(DS_VERSION, None) + + frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) + + z_model_state = zero_model_state(buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments) + zero_model_states.append(z_model_state) + + return zero_model_states + + +def parse_optim_states(files, ds_checkpoint_dir): + total_files = len(files) + state_dicts = [] + for f in tqdm(files, desc='Loading checkpoint shards'): + state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False) + # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights + # and also handle the case where it was already removed by another helper script + state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) + state_dicts.append(state_dict) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." + ) + + # the groups are named differently in each stage + if zero_stage <= 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_files = get_model_state_files(ds_checkpoint_dir) + + zero_model_states = parse_model_states(model_files) + print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') + + if zero_stage <= 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + + +def _zero2_merge_frozen_params(state_dict, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + frozen_param_fragments = zero_model_states[0].frozen_param_fragments + + if debug: + num_elem = sum(s.numel() for s in frozen_param_shapes.values()) + print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + state_dict[name] = frozen_param_fragments[name] + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _has_callable(obj, fn): + attr = getattr(obj, fn, None) + return callable(attr) + + +def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum( + [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero2_merge_frozen_params(state_dict, zero_model_states) + + _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + if debug: + for i in range(world_size): + num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) + print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in zero_model_states[0].frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) + state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +class GatheredTensor: + """ + A pseudo tensor that collects partitioned weights. + It is more memory efficient when there are multiple groups. + """ + + def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape): + self.flat_groups = flat_groups + self.flat_groups_offset = flat_groups_offset + self.offset = offset + self.partitioned_numel = partitioned_numel + self.shape = shape + self.dtype = self.flat_groups[0][0].dtype + + def contiguous(self): + """ + Merge partitioned weights from flat_groups into a single tensor. + """ + end_idx = self.offset + self.partitioned_numel + world_size = len(self.flat_groups) + pad_flat_param_chunks = [] + + for rank_i in range(world_size): + # for each rank, we need to collect weights from related group/groups + flat_groups_at_rank_i = self.flat_groups[rank_i] + start_group_id = None + end_group_id = None + for group_id in range(len(self.flat_groups_offset)): + if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]: + start_group_id = group_id + if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]: + end_group_id = group_id + break + # collect weights from related group/groups + for group_id in range(start_group_id, end_group_id + 1): + flat_tensor = flat_groups_at_rank_i[group_id] + start_offset = self.offset - self.flat_groups_offset[group_id] + end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id] + pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset]) + + # collect weights from all ranks + pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0) + param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous() + return param + + +def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size + + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + avail_numel = fp32_flat_groups[0].numel() * world_size + print(f"Trainable params: Have {avail_numel} numels to process.") + print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]])) + for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'): + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # memory efficient tensor + tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape) + state_dict[name] = tensor + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + + _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def to_torch_tensor(state_dict, return_empty_tensor=False): + """ + Convert state_dict of GatheredTensor to torch tensor + """ + torch_state_dict = {} + converted_tensors = {} + for name, tensor in state_dict.items(): + tensor_id = id(tensor) + if tensor_id in converted_tensors: # shared tensors + shared_tensor = torch_state_dict[converted_tensors[tensor_id]] + torch_state_dict[name] = shared_tensor + else: + converted_tensors[tensor_id] = name + if return_empty_tensor: + torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype) + else: + torch_state_dict[name] = tensor.contiguous() + return torch_state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, + tag=None, + exclude_frozen_parameters=False, + lazy_mode=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pesduo tensor instead of torch tensor, which is more memory efficient. + Convert the pesduo tensor to torch tensor by ``.contiguous()`` + + Returns: + - pytorch ``state_dict`` + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + Note: the above usage may not work if your application doesn't have sufficient free CPU memory. + You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. Or you can load state_dict in lazy mode :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu + for name, lazy_tensor in state_dict.item(): + tensor = lazy_tensor.contiguous() # to cpu + print(name, tensor) + # del tensor to release memory if it no longer in use + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) + if lazy_mode: + return state_dict + else: + return to_torch_tensor(state_dict) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, + output_dir, + max_shard_size="5GB", + safe_serialization=False, + tag=None, + exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_dir``: directory to the pytorch fp32 state_dict output files + - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB + - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + """ + + # Dependency pre-check + if safe_serialization: + try: + from safetensors.torch import save_file + except ImportError: + print('If you want to use `safe_serialization`, please `pip install safetensors`') + raise + if max_shard_size is not None: + try: + from huggingface_hub import split_torch_state_dict_into_shards + except ImportError: + print('If you want to use `max_shard_size`, please `pip install huggingface_hub`') + raise + + # Convert zero checkpoint to state_dict + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, + tag, + exclude_frozen_parameters, + lazy_mode=True) + + # Shard the model if it is too big. + weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin" + if max_shard_size is not None: + filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors") + # an memory-efficient approach for sharding + empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True) + state_dict_split = split_torch_state_dict_into_shards(empty_state_dict, + filename_pattern=filename_pattern, + max_shard_size=max_shard_size) + else: + from collections import namedtuple + StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"]) + state_dict_split = StateDictSplit(is_sharded=False, + filename_to_tensors={weights_name: list(state_dict.keys())}) + + # Save the model by shard + os.makedirs(output_dir, exist_ok=True) + filename_to_tensors = state_dict_split.filename_to_tensors.items() + for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"): + shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors} + shard_state_dict = to_torch_tensor(shard_state_dict) + output_path = os.path.join(output_dir, shard_file) + if safe_serialization: + save_file(shard_state_dict, output_path, metadata={"format": "pt"}) + else: + torch.save(shard_state_dict, output_path) + # release the memory of current shard + for tensor_name in list(shard_state_dict.keys()): + del state_dict[tensor_name] + del shard_state_dict[tensor_name] + del shard_state_dict + gc.collect() + + # Save index if sharded + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json" + save_index_file = os.path.join(output_dir, save_index_file) + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model`: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument("output_dir", + type=str, + help="directory to the pytorch fp32 state_dict output files" + "(e.g. path/checkpoint-12-output/)") + parser.add_argument( + "--max_shard_size", + type=str, + default="5GB", + help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size" + "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`" + "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances" + "without CPU OOM issues.") + parser.add_argument( + "--safe_serialization", + default=False, + action='store_true', + help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).") + parser.add_argument("-t", + "--tag", + type=str, + default=None, + help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") + parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, + args.output_dir, + max_shard_size=args.max_shard_size, + safe_serialization=args.safe_serialization, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters)