diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..30caea95a45813cb825905d0cd54f196d67c98df 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+best/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..bcf413f06e2b28bb06caaf150d32365c94fb6fa5
--- /dev/null
+++ b/README.md
@@ -0,0 +1,57 @@
+---
+library_name: transformers
+model_name: Qwen-sft
+tags:
+- generated_from_trainer
+- trl
+- sft
+license: license
+---
+
+# Model Card for Qwen-sft
+
+This model is a fine-tuned version of an unspecified base model (the auto-generated card did not record one).
+It has been trained using [TRL](https://github.com/huggingface/trl).
+
+## Quick start
+
+```python
+from transformers import pipeline
+
+question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?"
+generator = pipeline("text-generation", model="Alphatao/Qwen-sft", device="cuda")
+output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0]
+print(output["generated_text"])
+```
+
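+Under the hood, the tokenizer ships a DeepSeek-style chat template that wraps turns in `<|User|>` / `<|Assistant|>` markers. A minimal sketch of building the same prompt by hand (assuming the repo id above):
+
+```python
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("Alphatao/Qwen-sft")
+prompt = tokenizer.apply_chat_template(
+    [{"role": "user", "content": "Hello!"}],
+    tokenize=False,
+    add_generation_prompt=True,
+)
+# Expected form: "<|begin▁of▁sentence|><|User|>Hello!<|Assistant|>"
+print(prompt)
+```
+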
+## Training procedure
+
+[Visualize in Weights & Biases](https://wandb.ai/alphatao-alphatao/Gradients-On-Demand/runs/igefnzy4)
+
+
+This model was trained with SFT.
+
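+As a rough, hypothetical sketch of what such a run looks like with TRL's `SFTTrainer` (the actual base model and dataset are not recorded in this repository):
+
+```python
+from datasets import load_dataset
+from trl import SFTConfig, SFTTrainer
+
+# Placeholder base model and dataset, for illustration only.
+dataset = load_dataset("trl-lib/Capybara", split="train")
+trainer = SFTTrainer(
+    model="Qwen/Qwen3-8B",
+    train_dataset=dataset,
+    args=SFTConfig(output_dir="Qwen-sft"),
+)
+trainer.train()
+```
+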
+### Framework versions
+
+- TRL: 0.20.0
+- Transformers: 4.54.1
+- Pytorch: 2.7.1
+- Datasets: 4.0.0
+- Tokenizers: 0.21.4
+
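+### DeepSpeed checkpoint
+
+The repository also contains the raw DeepSpeed ZeRO shards (`global_step50/`, duplicated under `best/`) together with DeepSpeed's `zero_to_fp32.py` helper. A minimal sketch of consolidating those shards into a single fp32 state dict, assuming you run it from the checkpoint directory with the LFS files pulled:
+
+```python
+from zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+
+# `latest` in the checkpoint directory points at `global_step50`, so no explicit tag is needed.
+convert_zero_checkpoint_to_fp32_state_dict(
+    ".",             # checkpoint dir containing `latest` and `global_step50/`
+    "fp32_output/",  # where the consolidated weights are written
+    safe_serialization=True,  # write safetensors instead of pickle
+)
+```
+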
+## Citations
+
+
+
+Cite TRL as:
+
+```bibtex
+@misc{vonwerra2022trl,
+ title = {{TRL: Transformer Reinforcement Learning}},
+ author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec},
+ year = 2020,
+ journal = {GitHub repository},
+ publisher = {GitHub},
+ howpublished = {\url{https://github.com/huggingface/trl}}
+}
+```
\ No newline at end of file
diff --git a/best/chat_template.jinja b/best/chat_template.jinja
new file mode 100644
index 0000000000000000000000000000000000000000..7641578becb325296855fbde0f59778cae094eee
--- /dev/null
+++ b/best/chat_template.jinja
@@ -0,0 +1,14 @@
+{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true, is_last_user=false) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '
+
+' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{% set content = message['content'] %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{%- set ns.is_first = false -%}{%- set ns.is_last_user = true -%}{{'<|User|>' + content + '<|Assistant|>'}}{%- endif %}{%- if message['role'] == 'assistant' %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{% endif %}{%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}{%- set ns.is_last_user = false -%}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{%- endif %}{%- set ns.is_first = false %}{%- set ns.is_tool = false -%}{%- set ns.is_output_first = true %}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if content is none %}{{'<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '
+' + '```json' + '
+' + tool['function']['arguments'] + '
+' + '```' + '<|tool▁call▁end|>'}}{%- else %}{{content + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '
+' + '```json' + '
+' + tool['function']['arguments'] + '
+' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'
+' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '
+' + '```json' + '
+' + tool['function']['arguments'] + '
+' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- endfor %}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none)%}{%- set ns.is_last_user = false -%}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + content + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{{content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_last_user = false -%}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'
+<|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_last_user and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}
\ No newline at end of file
diff --git a/best/config.json b/best/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..0c1dc2823d465ee169afa942d160d838fcf054fe
--- /dev/null
+++ b/best/config.json
@@ -0,0 +1,73 @@
+{
+ "architectures": [
+ "Qwen3ForCausalLM"
+ ],
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "bos_token_id": 151643,
+ "eos_token_id": 151645,
+ "head_dim": 128,
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "initializer_range": 0.02,
+ "intermediate_size": 12288,
+ "layer_types": [
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention"
+ ],
+ "max_position_embeddings": 131072,
+ "max_window_layers": 36,
+ "model_type": "qwen3",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 36,
+ "num_key_value_heads": 8,
+ "rms_norm_eps": 1e-06,
+ "rope_scaling": {
+ "attn_factor": 0.8782488562869419,
+ "factor": 4.0,
+ "original_max_position_embeddings": 32768,
+ "rope_type": "yarn"
+ },
+ "rope_theta": 1000000,
+ "sliding_window": null,
+ "tie_word_embeddings": false,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.54.1",
+ "use_cache": true,
+ "use_sliding_window": false,
+ "vocab_size": 151936
+}
diff --git a/best/generation_config.json b/best/generation_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..ec4b7b20a7fce4dfc594eb138f56f2445d24c138
--- /dev/null
+++ b/best/generation_config.json
@@ -0,0 +1,6 @@
+{
+ "_from_model_config": true,
+ "bos_token_id": 151643,
+ "eos_token_id": 151645,
+ "transformers_version": "4.54.1"
+}
diff --git a/best/global_step50/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt b/best/global_step50/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..de7a58d08ba83991a1c1764385bd47b86481f1b9
--- /dev/null
+++ b/best/global_step50/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6a61f5f4449b8684210094097e113ee5840d4baa1d278667c9dff2f8e579723
+size 12286638307
diff --git a/best/global_step50/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt b/best/global_step50/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..9ab96e0b22b8f717bdd25faa3a4fc427b4fd99a3
--- /dev/null
+++ b/best/global_step50/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7452a55b25f22ded6c261807ddcdefff88b57f60876ed8f6583c539026459e18
+size 12286638307
diff --git a/best/global_step50/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt b/best/global_step50/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..873deb72ababfb2074af5ef4c106341583c9cbd0
--- /dev/null
+++ b/best/global_step50/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c98855f45b15c81042921e19073a6b0294d42b297a105d04ccd31f7b6cc2efdc
+size 12286638307
diff --git a/best/global_step50/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt b/best/global_step50/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..f2fcd0dc5260a50e8c5144275937b3cabfb59b4c
--- /dev/null
+++ b/best/global_step50/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c2d0ff70f9cf2cccc86d3b13d5f3cb01f37ae7b5e32bc5902446c9cc62e48c9
+size 12286638307
diff --git a/best/global_step50/bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt b/best/global_step50/bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..e31c95614365fc8ff665f1a8183fcb8d6da51d68
--- /dev/null
+++ b/best/global_step50/bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d6c115bb3b3fa983a999017b5d09162367227eaf19eae5e97ac309f16c058614
+size 12286638307
diff --git a/best/global_step50/bf16_zero_pp_rank_5_mp_rank_00_optim_states.pt b/best/global_step50/bf16_zero_pp_rank_5_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..be78f26d0df03ed8af8043a53473df738e16e2ea
--- /dev/null
+++ b/best/global_step50/bf16_zero_pp_rank_5_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2c58d9150c29e2bd16e5a3ffde712bd4d813aa60a75245f7353fd7817713e60c
+size 12286638307
diff --git a/best/global_step50/bf16_zero_pp_rank_6_mp_rank_00_optim_states.pt b/best/global_step50/bf16_zero_pp_rank_6_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..b0724118583b2029d9ed4ec691a69f271d47dea7
--- /dev/null
+++ b/best/global_step50/bf16_zero_pp_rank_6_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d84045f90e183ed30c9366eb3488952d709b569bca72770b6387fa717548981
+size 12286638307
diff --git a/best/global_step50/bf16_zero_pp_rank_7_mp_rank_00_optim_states.pt b/best/global_step50/bf16_zero_pp_rank_7_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..07602dbed2b09e82b32915d98a64b034e8853ed2
--- /dev/null
+++ b/best/global_step50/bf16_zero_pp_rank_7_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:285677a6afd776084e36816daf44448c6432d8f5ceb459c8e824b9cc66ee3f93
+size 12286638307
diff --git a/best/global_step50/zero_pp_rank_0_mp_rank_00_model_states.pt b/best/global_step50/zero_pp_rank_0_mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..4f39a2afb8409d18cd9b9faf99c29c9d0247ef80
--- /dev/null
+++ b/best/global_step50/zero_pp_rank_0_mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:855f528cf0bab0d11da820f9e79cd4b5d15a5543b89de20def728410248138ed
+size 206444
diff --git a/best/global_step50/zero_pp_rank_1_mp_rank_00_model_states.pt b/best/global_step50/zero_pp_rank_1_mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..968716667206253a977eeac632b5a9ea4da15541
--- /dev/null
+++ b/best/global_step50/zero_pp_rank_1_mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:455cf17a101b529d0d1996e16d033bae6206b3a723d9092bc3086d3b4b74f1e3
+size 206444
diff --git a/best/global_step50/zero_pp_rank_2_mp_rank_00_model_states.pt b/best/global_step50/zero_pp_rank_2_mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..14b7e620752a941e3efeae52bc9f4b12d7767f23
--- /dev/null
+++ b/best/global_step50/zero_pp_rank_2_mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a92ac45256f9a630ce36461320b438f72309ff28ee8911eb62c21b5ee1ec105a
+size 206444
diff --git a/best/global_step50/zero_pp_rank_3_mp_rank_00_model_states.pt b/best/global_step50/zero_pp_rank_3_mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..878c342108481839e0dab7a45818b2070d900164
--- /dev/null
+++ b/best/global_step50/zero_pp_rank_3_mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b87bfd3b93bf7fec20ce6c4b775d763a2d7622f7129dbb82a2c92fe6815e63da
+size 206444
diff --git a/best/global_step50/zero_pp_rank_4_mp_rank_00_model_states.pt b/best/global_step50/zero_pp_rank_4_mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..a5aadd473ab9fd1cbc878dfb5131b64805742287
--- /dev/null
+++ b/best/global_step50/zero_pp_rank_4_mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ede2f27e0e87806776ed677b520ba54b19ae7a47b086560e1f9136559fda8323
+size 206444
diff --git a/best/global_step50/zero_pp_rank_5_mp_rank_00_model_states.pt b/best/global_step50/zero_pp_rank_5_mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..bf4baa61269ad7dd2f496b95dfbf898d0ee54f87
--- /dev/null
+++ b/best/global_step50/zero_pp_rank_5_mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fdc129e37acb075cbff9d3a13b72f497750b08a5dce4bfb09a2d120ae7d657cf
+size 206444
diff --git a/best/global_step50/zero_pp_rank_6_mp_rank_00_model_states.pt b/best/global_step50/zero_pp_rank_6_mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..2febdfcd48af5d7bb0c559b036c5f9b213c0c356
--- /dev/null
+++ b/best/global_step50/zero_pp_rank_6_mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f55a26882e198d70abf070567520679fdc7abc58a8a78c4cf0d0038aec3c0924
+size 206444
diff --git a/best/global_step50/zero_pp_rank_7_mp_rank_00_model_states.pt b/best/global_step50/zero_pp_rank_7_mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..5c37cb2d40bbd3bd1ddc9636a08ee8b2c224e3ba
--- /dev/null
+++ b/best/global_step50/zero_pp_rank_7_mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52ca437d03bc1c7d2e1496207bd403dab3e25b4f4bf53d039885675a6fe5e694
+size 206444
diff --git a/best/latest b/best/latest
new file mode 100644
index 0000000000000000000000000000000000000000..9b4dc801e3fb152ef5c0ee60d309c705a9b01564
--- /dev/null
+++ b/best/latest
@@ -0,0 +1 @@
+global_step50
\ No newline at end of file
diff --git a/best/special_tokens_map.json b/best/special_tokens_map.json
new file mode 100644
index 0000000000000000000000000000000000000000..1d385d62cf08bca35254547902b792c243656ec1
--- /dev/null
+++ b/best/special_tokens_map.json
@@ -0,0 +1,23 @@
+{
+ "bos_token": {
+ "content": "<|begin▁of▁sentence|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "<|end▁of▁sentence|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": {
+ "content": "<|end▁of▁sentence|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+}
diff --git a/best/tokenizer.json b/best/tokenizer.json
new file mode 100644
index 0000000000000000000000000000000000000000..d768bd0f05cc9821f74a87e2ec4e1229a09fa389
--- /dev/null
+++ b/best/tokenizer.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:93d5fd6d2f8cf1172ac86cf982e2b88fa6732366b44dc1a32349379a54a6a044
+size 11423346
diff --git a/best/tokenizer_config.json b/best/tokenizer_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..c4657dc1d512391538fdb16dd905945ad376062f
--- /dev/null
+++ b/best/tokenizer_config.json
@@ -0,0 +1,242 @@
+{
+ "add_bos_token": false,
+ "add_eos_token": false,
+ "add_prefix_space": null,
+ "added_tokens_decoder": {
+ "151643": {
+ "content": "<|begin▁of▁sentence|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151644": {
+ "content": "<|im_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151645": {
+ "content": "<|end▁of▁sentence|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151646": {
+ "content": "<|object_ref_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151647": {
+ "content": "<|object_ref_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151648": {
+ "content": "<|box_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151649": {
+ "content": "<|box_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151650": {
+ "content": "<|quad_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151651": {
+ "content": "<|quad_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151652": {
+ "content": "<|vision_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151653": {
+ "content": "<|vision_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151654": {
+ "content": "<|vision_pad|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151655": {
+ "content": "<|image_pad|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151656": {
+ "content": "<|video_pad|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151657": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151658": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151659": {
+ "content": "<|fim_prefix|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151660": {
+ "content": "<|fim_middle|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151661": {
+ "content": "<|fim_suffix|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151662": {
+ "content": "<|fim_pad|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151663": {
+ "content": "<|repo_name|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151664": {
+ "content": "<|file_sep|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151665": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151666": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151667": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151668": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151669": {
+ "content": "<|User|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151670": {
+ "content": "<|Assistant|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ }
+ },
+ "bos_token": "<|begin▁of▁sentence|>",
+ "clean_up_tokenization_spaces": false,
+ "eos_token": "<|end▁of▁sentence|>",
+ "extra_special_tokens": {},
+ "legacy": true,
+ "model_max_length": 131072,
+ "pad_token": "<|end▁of▁sentence|>",
+ "sp_model_kwargs": {},
+ "tokenizer_class": "LlamaTokenizerFast",
+ "unk_token": null,
+ "use_default_system_prompt": false
+}
diff --git a/best/training_args.bin b/best/training_args.bin
new file mode 100644
index 0000000000000000000000000000000000000000..a275320c73f948be3ab0a894560d261ebe08c98a
--- /dev/null
+++ b/best/training_args.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:840c9e203e24a719dec7c5250c5553e38d741952e631bfb2dfe7bddec9d6e991
+size 7633
diff --git a/best/zero_to_fp32.py b/best/zero_to_fp32.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e759146cadd92ddfefab3680146c2bd6a2b5c04
--- /dev/null
+++ b/best/zero_to_fp32.py
@@ -0,0 +1,760 @@
+#!/usr/bin/env python
+
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+# This script extracts fp32 consolidated weights from ZeRO stage 1, 2 and 3 DeepSpeed checkpoints. It gets
+# copied into the top-level checkpoint dir, so the user can easily do the conversion at any point in
+# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
+# application.
+#
+# example:
+# python zero_to_fp32.py . output_dir/
+# or
+# python zero_to_fp32.py . output_dir/ --safe_serialization
+
+import argparse
+import torch
+import glob
+import math
+import os
+import re
+import gc
+import json
+import numpy as np
+from tqdm import tqdm
+from collections import OrderedDict
+from dataclasses import dataclass
+
+# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
+# DeepSpeed data structures it has to be available in the current python environment.
+from deepspeed.utils import logger
+from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
+
+
+@dataclass
+class zero_model_state:
+    buffers: dict
+    param_shapes: dict
+    shared_params: list
+    ds_version: int
+    frozen_param_shapes: dict
+    frozen_param_fragments: dict
+
+
+debug = 0
+
+# load to cpu
+device = torch.device('cpu')
+
+
+def atoi(text):
+ return int(text) if text.isdigit() else text
+
+
+def natural_keys(text):
+ '''
+ alist.sort(key=natural_keys) sorts in human order
+ http://nedbatchelder.com/blog/200712/human_sorting.html
+ (See Toothy's implementation in the comments)
+ '''
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
+
+
+def get_model_state_file(checkpoint_dir, zero_stage):
+ if not os.path.isdir(checkpoint_dir):
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
+
+ # there should be only one file
+ if zero_stage <= 2:
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
+ elif zero_stage == 3:
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
+
+ if not os.path.exists(file):
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
+
+ return file
+
+
+def get_checkpoint_files(checkpoint_dir, glob_pattern):
+ # XXX: need to test that this simple glob rule works for multi-node setup too
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
+
+ if len(ckpt_files) == 0:
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
+
+ return ckpt_files
+
+
+def get_optim_files(checkpoint_dir):
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
+
+
+def get_model_state_files(checkpoint_dir):
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
+
+
+def parse_model_states(files):
+ zero_model_states = []
+ for file in files:
+ state_dict = torch.load(file, map_location=device, weights_only=False)
+
+ if BUFFER_NAMES not in state_dict:
+ raise ValueError(f"{file} is not a model state checkpoint")
+ buffer_names = state_dict[BUFFER_NAMES]
+ if debug:
+ print("Found buffers:", buffer_names)
+
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
+ param_shapes = state_dict[PARAM_SHAPES]
+
+ # collect parameters that are included in param_shapes
+ param_names = []
+ for s in param_shapes:
+ for name in s.keys():
+ param_names.append(name)
+
+ # update with frozen parameters
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
+ if frozen_param_shapes is not None:
+ if debug:
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
+ param_names += list(frozen_param_shapes.keys())
+
+ # handle shared params
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
+
+ ds_version = state_dict.get(DS_VERSION, None)
+
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
+
+ z_model_state = zero_model_state(buffers=buffers,
+ param_shapes=param_shapes,
+ shared_params=shared_params,
+ ds_version=ds_version,
+ frozen_param_shapes=frozen_param_shapes,
+ frozen_param_fragments=frozen_param_fragments)
+ zero_model_states.append(z_model_state)
+
+ return zero_model_states
+
+
+def parse_optim_states(files, ds_checkpoint_dir):
+ total_files = len(files)
+ state_dicts = []
+ for f in tqdm(files, desc='Loading checkpoint shards'):
+ state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
+        # immediately discard the two potentially huge optimizer states, as we only care about the fp32 master weights,
+        # and also handle the case where they were already removed by another helper script
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
+ state_dicts.append(state_dict)
+
+    if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
+
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
+ # use the max of the partition_count to get the dp world_size.
+
+ if type(world_size) is list:
+ world_size = max(world_size)
+
+ if world_size != total_files:
+ raise ValueError(
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
+ )
+
+ # the groups are named differently in each stage
+ if zero_stage <= 2:
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
+ elif zero_stage == 3:
+ fp32_groups_key = FP32_FLAT_GROUPS
+ else:
+ raise ValueError(f"unknown zero stage {zero_stage}")
+
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
+ return zero_stage, world_size, fp32_flat_groups
+
+
+def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
+ """
+ Returns fp32 state_dict reconstructed from ds checkpoint
+
+ Args:
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
+
+ """
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
+
+ optim_files = get_optim_files(ds_checkpoint_dir)
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
+
+ model_files = get_model_state_files(ds_checkpoint_dir)
+
+ zero_model_states = parse_model_states(model_files)
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
+
+ if zero_stage <= 2:
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+ exclude_frozen_parameters)
+ elif zero_stage == 3:
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+ exclude_frozen_parameters)
+
+
+def _zero2_merge_frozen_params(state_dict, zero_model_states):
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+ return
+
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
+
+ if debug:
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+
+ wanted_params = len(frozen_param_shapes)
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
+ print(f'Frozen params: Have {avail_numel} numels to process.')
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+
+ total_params = 0
+ total_numel = 0
+ for name, shape in frozen_param_shapes.items():
+ total_params += 1
+ unpartitioned_numel = shape.numel()
+ total_numel += unpartitioned_numel
+
+ state_dict[name] = frozen_param_fragments[name]
+
+ if debug:
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _has_callable(obj, fn):
+ attr = getattr(obj, fn, None)
+ return callable(attr)
+
+
+def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
+ param_shapes = zero_model_states[0].param_shapes
+
+    # Reconstruction protocol:
+    #
+    # Each rank stores a contiguous slice of every flattened fp32 param group. Concatenating
+    # the slices from all ranks (in rank order) rebuilds the full flat vector per group, and
+    # individual params are then carved out of it by walking `param_shapes` in order and
+    # narrowing off `shape.numel()` elements at a time.
+
+ if debug:
+ for i in range(world_size):
+ for j in range(len(fp32_flat_groups[0])):
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
+
+ # XXX: memory usage doubles here (zero2)
+ num_param_groups = len(fp32_flat_groups[0])
+ merged_single_partition_of_fp32_groups = []
+ for i in range(num_param_groups):
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
+ avail_numel = sum(
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
+
+ if debug:
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
+ # not asserting if there is a mismatch due to possible padding
+ print(f"Have {avail_numel} numels to process.")
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
+
+ # params
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+ # out-of-core computing solution
+ total_numel = 0
+ total_params = 0
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
+ offset = 0
+ avail_numel = full_single_fp32_vector.numel()
+ for name, shape in shapes.items():
+
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
+ total_numel += unpartitioned_numel
+ total_params += 1
+
+ if debug:
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
+ offset += unpartitioned_numel
+
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
+ # live optimizer object, so we are checking that the numbers are within the right range
+ align_to = 2 * world_size
+
+ def zero2_align(x):
+ return align_to * math.ceil(x / align_to)
+
+ if debug:
+ print(f"original offset={offset}, avail_numel={avail_numel}")
+
+ offset = zero2_align(offset)
+ avail_numel = zero2_align(avail_numel)
+
+ if debug:
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
+
+ # Sanity check
+ if offset != avail_numel:
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+ exclude_frozen_parameters):
+ state_dict = OrderedDict()
+
+ # buffers
+ buffers = zero_model_states[0].buffers
+ state_dict.update(buffers)
+ if debug:
+ print(f"added {len(buffers)} buffers")
+
+ if not exclude_frozen_parameters:
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
+
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+
+ # recover shared parameters
+ for pair in zero_model_states[0].shared_params:
+ if pair[1] in state_dict:
+ state_dict[pair[0]] = state_dict[pair[1]]
+
+ return state_dict
+
+
+def zero3_partitioned_param_info(unpartitioned_numel, world_size):
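+    # Example: unpartitioned_numel=10, world_size=4 -> remainder=2, padding_numel=2,
+    # partitioned_numel=3: each of the 4 ranks stores 3 elements (12 = 10 + 2 padding).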
+ remainder = unpartitioned_numel % world_size
+ padding_numel = (world_size - remainder) if remainder else 0
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
+ return partitioned_numel, padding_numel
+
+
+def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+ return
+
+ if debug:
+ for i in range(world_size):
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
+ wanted_params = len(frozen_param_shapes)
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
+ print(f'Frozen params: Have {avail_numel} numels to process.')
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+
+ total_params = 0
+ total_numel = 0
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
+ total_params += 1
+ unpartitioned_numel = shape.numel()
+ total_numel += unpartitioned_numel
+
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
+
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+ if debug:
+ print(
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+ )
+
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
+
+
+class GatheredTensor:
+ """
+ A pseudo tensor that collects partitioned weights.
+ It is more memory efficient when there are multiple groups.
+ """
+
+ def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
+ self.flat_groups = flat_groups
+ self.flat_groups_offset = flat_groups_offset
+ self.offset = offset
+ self.partitioned_numel = partitioned_numel
+ self.shape = shape
+ self.dtype = self.flat_groups[0][0].dtype
+
+ def contiguous(self):
+ """
+ Merge partitioned weights from flat_groups into a single tensor.
+ """
+ end_idx = self.offset + self.partitioned_numel
+ world_size = len(self.flat_groups)
+ pad_flat_param_chunks = []
+
+ for rank_i in range(world_size):
+ # for each rank, we need to collect weights from related group/groups
+ flat_groups_at_rank_i = self.flat_groups[rank_i]
+ start_group_id = None
+ end_group_id = None
+ for group_id in range(len(self.flat_groups_offset)):
+ if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
+ start_group_id = group_id
+ if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
+ end_group_id = group_id
+ break
+ # collect weights from related group/groups
+ for group_id in range(start_group_id, end_group_id + 1):
+ flat_tensor = flat_groups_at_rank_i[group_id]
+ start_offset = self.offset - self.flat_groups_offset[group_id]
+ end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
+ pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
+
+ # collect weights from all ranks
+ pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
+ param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
+ return param
+
+
+def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
+ param_shapes = zero_model_states[0].param_shapes
+ avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
+
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
+ # param, re-consolidating each param, while dealing with padding if any
+
+ # merge list of dicts, preserving order
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
+
+ if debug:
+ for i in range(world_size):
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
+
+ wanted_params = len(param_shapes)
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
+ # not asserting if there is a mismatch due to possible padding
+ avail_numel = fp32_flat_groups[0].numel() * world_size
+ print(f"Trainable params: Have {avail_numel} numels to process.")
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
+
+ # params
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+ # out-of-core computing solution
+ offset = 0
+ total_numel = 0
+ total_params = 0
+ flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
+ for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
+ unpartitioned_numel = shape.numel()
+ total_numel += unpartitioned_numel
+ total_params += 1
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+ if debug:
+ print(
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+ )
+
+ # memory efficient tensor
+ tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
+ state_dict[name] = tensor
+ offset += partitioned_numel
+
+ offset *= world_size
+
+ # Sanity check
+ if offset != avail_numel:
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+ exclude_frozen_parameters):
+ state_dict = OrderedDict()
+
+ # buffers
+ buffers = zero_model_states[0].buffers
+ state_dict.update(buffers)
+ if debug:
+ print(f"added {len(buffers)} buffers")
+
+ if not exclude_frozen_parameters:
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
+
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+
+ # recover shared parameters
+ for pair in zero_model_states[0].shared_params:
+ if pair[1] in state_dict:
+ state_dict[pair[0]] = state_dict[pair[1]]
+
+ return state_dict
+
+
+def to_torch_tensor(state_dict, return_empty_tensor=False):
+ """
+ Convert state_dict of GatheredTensor to torch tensor
+ """
+ torch_state_dict = {}
+ converted_tensors = {}
+ for name, tensor in state_dict.items():
+ tensor_id = id(tensor)
+ if tensor_id in converted_tensors: # shared tensors
+ shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
+ torch_state_dict[name] = shared_tensor
+ else:
+ converted_tensors[tensor_id] = name
+ if return_empty_tensor:
+ torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
+ else:
+ torch_state_dict[name] = tensor.contiguous()
+ return torch_state_dict
+
+
+def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
+ tag=None,
+ exclude_frozen_parameters=False,
+ lazy_mode=False):
+ """
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
+ via a model hub.
+
+ Args:
+ - ``checkpoint_dir``: path to the desired checkpoint folder
+        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to load the tag from the 'latest' file, e.g., ``global_step14``
+        - ``exclude_frozen_parameters``: exclude frozen parameters
+        - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pseudo tensors instead of torch tensors, which is more memory efficient.
+          Convert a pseudo tensor to a torch tensor by calling ``.contiguous()``
+
+ Returns:
+ - pytorch ``state_dict``
+
+ A typical usage might be ::
+
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+ # do the training and checkpoint saving
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
+ model = model.cpu() # move to cpu
+ model.load_state_dict(state_dict)
+ # submit to model hub or save the model to share with others
+
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
+ application. i.e. you will need to re-initialize the deepspeed engine, since
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
+
+ Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
+ You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
+ the checkpoint. Or you can load state_dict in lazy mode ::
+
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
+        for name, lazy_tensor in state_dict.items():
+            tensor = lazy_tensor.contiguous() # to cpu
+            print(name, tensor)
+            # del tensor to release memory if it is no longer in use
+ """
+ if tag is None:
+ latest_path = os.path.join(checkpoint_dir, 'latest')
+ if os.path.isfile(latest_path):
+ with open(latest_path, 'r') as fd:
+ tag = fd.read().strip()
+ else:
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
+
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
+
+ if not os.path.isdir(ds_checkpoint_dir):
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
+
+ state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
+ if lazy_mode:
+ return state_dict
+ else:
+ return to_torch_tensor(state_dict)
+
+
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
+ output_dir,
+ max_shard_size="5GB",
+ safe_serialization=False,
+ tag=None,
+ exclude_frozen_parameters=False):
+ """
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
+
+ Args:
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+ - ``output_dir``: directory to the pytorch fp32 state_dict output files
+ - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
+ - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+ - ``exclude_frozen_parameters``: exclude frozen parameters
+ """
+
+ # Dependency pre-check
+ if safe_serialization:
+ try:
+ from safetensors.torch import save_file
+ except ImportError:
+ print('If you want to use `safe_serialization`, please `pip install safetensors`')
+ raise
+ if max_shard_size is not None:
+ try:
+ from huggingface_hub import split_torch_state_dict_into_shards
+ except ImportError:
+ print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
+ raise
+
+ # Convert zero checkpoint to state_dict
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
+ tag,
+ exclude_frozen_parameters,
+ lazy_mode=True)
+
+ # Shard the model if it is too big.
+ weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
+ if max_shard_size is not None:
+ filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
+        # a memory-efficient approach for sharding
+ empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
+ state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
+ filename_pattern=filename_pattern,
+ max_shard_size=max_shard_size)
+ else:
+ from collections import namedtuple
+ StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
+ state_dict_split = StateDictSplit(is_sharded=False,
+ filename_to_tensors={weights_name: list(state_dict.keys())})
+
+ # Save the model by shard
+ os.makedirs(output_dir, exist_ok=True)
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
+ for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
+ shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
+ shard_state_dict = to_torch_tensor(shard_state_dict)
+ output_path = os.path.join(output_dir, shard_file)
+ if safe_serialization:
+ save_file(shard_state_dict, output_path, metadata={"format": "pt"})
+ else:
+ torch.save(shard_state_dict, output_path)
+ # release the memory of current shard
+ for tensor_name in list(shard_state_dict.keys()):
+ del state_dict[tensor_name]
+ del shard_state_dict[tensor_name]
+ del shard_state_dict
+ gc.collect()
+
+ # Save index if sharded
+ if state_dict_split.is_sharded:
+ index = {
+ "metadata": state_dict_split.metadata,
+ "weight_map": state_dict_split.tensor_to_filename,
+ }
+ save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
+ save_index_file = os.path.join(output_dir, save_index_file)
+ with open(save_index_file, "w", encoding="utf-8") as f:
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+ f.write(content)
+
+
+def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
+ """
+ 1. Put the provided model to cpu
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
+ 3. Load it into the provided model
+
+ Args:
+ - ``model``: the model object to update
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+
+ Returns:
+    - ``model``: the modified model
+
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
+ conveniently placed for you in the checkpoint folder.
+
+ A typical usage might be ::
+
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
+ # submit to model hub or save the model to share with others
+
+    Note that once this has been run, the ``model`` will no longer be usable in the deepspeed context
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+ """
+ logger.info(f"Extracting fp32 weights")
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+
+ logger.info(f"Overwriting model with fp32 weights")
+ model = model.cpu()
+ model.load_state_dict(state_dict, strict=False)
+
+ return model
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("checkpoint_dir",
+ type=str,
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
+ parser.add_argument("output_dir",
+ type=str,
+ help="directory to the pytorch fp32 state_dict output files"
+ "(e.g. path/checkpoint-12-output/)")
+ parser.add_argument(
+ "--max_shard_size",
+ type=str,
+ default="5GB",
+ help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
+ "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
+ "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
+ "without CPU OOM issues.")
+ parser.add_argument(
+ "--safe_serialization",
+ default=False,
+ action='store_true',
+ help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
+ parser.add_argument("-t",
+ "--tag",
+ type=str,
+ default=None,
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
+ parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
+ args = parser.parse_args()
+
+ debug = args.debug
+
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
+ args.output_dir,
+ max_shard_size=args.max_shard_size,
+ safe_serialization=args.safe_serialization,
+ tag=args.tag,
+ exclude_frozen_parameters=args.exclude_frozen_parameters)
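+
+# Note: for the checkpoint shipped alongside this script, the `latest` file points at
+# `global_step50`, so running e.g. `python zero_to_fp32.py . fp32_output/` from that
+# directory consolidates the eight-rank ZeRO shards into fp32 weights.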
diff --git a/chat_template.jinja b/chat_template.jinja
new file mode 100644
index 0000000000000000000000000000000000000000..7641578becb325296855fbde0f59778cae094eee
--- /dev/null
+++ b/chat_template.jinja
@@ -0,0 +1,14 @@
+{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true, is_last_user=false) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '
+
+' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{% set content = message['content'] %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{%- set ns.is_first = false -%}{%- set ns.is_last_user = true -%}{{'<|User|>' + content + '<|Assistant|>'}}{%- endif %}{%- if message['role'] == 'assistant' %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{% endif %}{%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}{%- set ns.is_last_user = false -%}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{%- endif %}{%- set ns.is_first = false %}{%- set ns.is_tool = false -%}{%- set ns.is_output_first = true %}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if content is none %}{{'<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '
+' + '```json' + '
+' + tool['function']['arguments'] + '
+' + '```' + '<|tool▁call▁end|>'}}{%- else %}{{content + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '
+' + '```json' + '
+' + tool['function']['arguments'] + '
+' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'
+' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '
+' + '```json' + '
+' + tool['function']['arguments'] + '
+' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- endfor %}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none)%}{%- set ns.is_last_user = false -%}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + content + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{{content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_last_user = false -%}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'
+<|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_last_user and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}
\ No newline at end of file
diff --git a/config.json b/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..0c1dc2823d465ee169afa942d160d838fcf054fe
--- /dev/null
+++ b/config.json
@@ -0,0 +1,73 @@
+{
+ "architectures": [
+ "Qwen3ForCausalLM"
+ ],
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "bos_token_id": 151643,
+ "eos_token_id": 151645,
+ "head_dim": 128,
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "initializer_range": 0.02,
+ "intermediate_size": 12288,
+ "layer_types": [
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention",
+ "full_attention"
+ ],
+ "max_position_embeddings": 131072,
+ "max_window_layers": 36,
+ "model_type": "qwen3",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 36,
+ "num_key_value_heads": 8,
+ "rms_norm_eps": 1e-06,
+ "rope_scaling": {
+ "attn_factor": 0.8782488562869419,
+ "factor": 4.0,
+ "original_max_position_embeddings": 32768,
+ "rope_type": "yarn"
+ },
+ "rope_theta": 1000000,
+ "sliding_window": null,
+ "tie_word_embeddings": false,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.54.1",
+ "use_cache": true,
+ "use_sliding_window": false,
+ "vocab_size": 151936
+}
diff --git a/generation_config.json b/generation_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..ec4b7b20a7fce4dfc594eb138f56f2445d24c138
--- /dev/null
+++ b/generation_config.json
@@ -0,0 +1,6 @@
+{
+ "_from_model_config": true,
+ "bos_token_id": 151643,
+ "eos_token_id": 151645,
+ "transformers_version": "4.54.1"
+}
diff --git a/global_step50/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt b/global_step50/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..de7a58d08ba83991a1c1764385bd47b86481f1b9
--- /dev/null
+++ b/global_step50/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6a61f5f4449b8684210094097e113ee5840d4baa1d278667c9dff2f8e579723
+size 12286638307
diff --git a/global_step50/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt b/global_step50/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..9ab96e0b22b8f717bdd25faa3a4fc427b4fd99a3
--- /dev/null
+++ b/global_step50/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7452a55b25f22ded6c261807ddcdefff88b57f60876ed8f6583c539026459e18
+size 12286638307
diff --git a/global_step50/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt b/global_step50/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..873deb72ababfb2074af5ef4c106341583c9cbd0
--- /dev/null
+++ b/global_step50/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c98855f45b15c81042921e19073a6b0294d42b297a105d04ccd31f7b6cc2efdc
+size 12286638307
diff --git a/global_step50/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt b/global_step50/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..f2fcd0dc5260a50e8c5144275937b3cabfb59b4c
--- /dev/null
+++ b/global_step50/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c2d0ff70f9cf2cccc86d3b13d5f3cb01f37ae7b5e32bc5902446c9cc62e48c9
+size 12286638307
diff --git a/global_step50/bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt b/global_step50/bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..e31c95614365fc8ff665f1a8183fcb8d6da51d68
--- /dev/null
+++ b/global_step50/bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d6c115bb3b3fa983a999017b5d09162367227eaf19eae5e97ac309f16c058614
+size 12286638307
diff --git a/global_step50/bf16_zero_pp_rank_5_mp_rank_00_optim_states.pt b/global_step50/bf16_zero_pp_rank_5_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..be78f26d0df03ed8af8043a53473df738e16e2ea
--- /dev/null
+++ b/global_step50/bf16_zero_pp_rank_5_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2c58d9150c29e2bd16e5a3ffde712bd4d813aa60a75245f7353fd7817713e60c
+size 12286638307
diff --git a/global_step50/bf16_zero_pp_rank_6_mp_rank_00_optim_states.pt b/global_step50/bf16_zero_pp_rank_6_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..b0724118583b2029d9ed4ec691a69f271d47dea7
--- /dev/null
+++ b/global_step50/bf16_zero_pp_rank_6_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d84045f90e183ed30c9366eb3488952d709b569bca72770b6387fa717548981
+size 12286638307
diff --git a/global_step50/bf16_zero_pp_rank_7_mp_rank_00_optim_states.pt b/global_step50/bf16_zero_pp_rank_7_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..07602dbed2b09e82b32915d98a64b034e8853ed2
--- /dev/null
+++ b/global_step50/bf16_zero_pp_rank_7_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:285677a6afd776084e36816daf44448c6432d8f5ceb459c8e824b9cc66ee3f93
+size 12286638307
diff --git a/global_step50/zero_pp_rank_0_mp_rank_00_model_states.pt b/global_step50/zero_pp_rank_0_mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..4f39a2afb8409d18cd9b9faf99c29c9d0247ef80
--- /dev/null
+++ b/global_step50/zero_pp_rank_0_mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:855f528cf0bab0d11da820f9e79cd4b5d15a5543b89de20def728410248138ed
+size 206444
diff --git a/global_step50/zero_pp_rank_1_mp_rank_00_model_states.pt b/global_step50/zero_pp_rank_1_mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..968716667206253a977eeac632b5a9ea4da15541
--- /dev/null
+++ b/global_step50/zero_pp_rank_1_mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:455cf17a101b529d0d1996e16d033bae6206b3a723d9092bc3086d3b4b74f1e3
+size 206444
diff --git a/global_step50/zero_pp_rank_2_mp_rank_00_model_states.pt b/global_step50/zero_pp_rank_2_mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..14b7e620752a941e3efeae52bc9f4b12d7767f23
--- /dev/null
+++ b/global_step50/zero_pp_rank_2_mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a92ac45256f9a630ce36461320b438f72309ff28ee8911eb62c21b5ee1ec105a
+size 206444
diff --git a/global_step50/zero_pp_rank_3_mp_rank_00_model_states.pt b/global_step50/zero_pp_rank_3_mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..878c342108481839e0dab7a45818b2070d900164
--- /dev/null
+++ b/global_step50/zero_pp_rank_3_mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b87bfd3b93bf7fec20ce6c4b775d763a2d7622f7129dbb82a2c92fe6815e63da
+size 206444
diff --git a/global_step50/zero_pp_rank_4_mp_rank_00_model_states.pt b/global_step50/zero_pp_rank_4_mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..a5aadd473ab9fd1cbc878dfb5131b64805742287
--- /dev/null
+++ b/global_step50/zero_pp_rank_4_mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ede2f27e0e87806776ed677b520ba54b19ae7a47b086560e1f9136559fda8323
+size 206444
diff --git a/global_step50/zero_pp_rank_5_mp_rank_00_model_states.pt b/global_step50/zero_pp_rank_5_mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..bf4baa61269ad7dd2f496b95dfbf898d0ee54f87
--- /dev/null
+++ b/global_step50/zero_pp_rank_5_mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fdc129e37acb075cbff9d3a13b72f497750b08a5dce4bfb09a2d120ae7d657cf
+size 206444
diff --git a/global_step50/zero_pp_rank_6_mp_rank_00_model_states.pt b/global_step50/zero_pp_rank_6_mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..2febdfcd48af5d7bb0c559b036c5f9b213c0c356
--- /dev/null
+++ b/global_step50/zero_pp_rank_6_mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f55a26882e198d70abf070567520679fdc7abc58a8a78c4cf0d0038aec3c0924
+size 206444
diff --git a/global_step50/zero_pp_rank_7_mp_rank_00_model_states.pt b/global_step50/zero_pp_rank_7_mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..5c37cb2d40bbd3bd1ddc9636a08ee8b2c224e3ba
--- /dev/null
+++ b/global_step50/zero_pp_rank_7_mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52ca437d03bc1c7d2e1496207bd403dab3e25b4f4bf53d039885675a6fe5e694
+size 206444
diff --git a/latest b/latest
new file mode 100644
index 0000000000000000000000000000000000000000..9b4dc801e3fb152ef5c0ee60d309c705a9b01564
--- /dev/null
+++ b/latest
@@ -0,0 +1 @@
+global_step50
\ No newline at end of file
diff --git a/logs/events.out.tfevents.1754432204.1506d310068f.1511638.0 b/logs/events.out.tfevents.1754432204.1506d310068f.1511638.0
new file mode 100644
index 0000000000000000000000000000000000000000..2411c477359c3dd2724e55d6d836136ff84a676b
--- /dev/null
+++ b/logs/events.out.tfevents.1754432204.1506d310068f.1511638.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e14ceabaf01edd8924fcf9d85eb276741254077f3cbb62108e12689b0066ac64
+size 23317
diff --git a/special_tokens_map.json b/special_tokens_map.json
new file mode 100644
index 0000000000000000000000000000000000000000..1d385d62cf08bca35254547902b792c243656ec1
--- /dev/null
+++ b/special_tokens_map.json
@@ -0,0 +1,23 @@
+{
+ "bos_token": {
+ "content": "<|begin▁of▁sentence|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "<|end▁of▁sentence|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": {
+ "content": "<|end▁of▁sentence|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+}
diff --git a/tokenizer.json b/tokenizer.json
new file mode 100644
index 0000000000000000000000000000000000000000..d768bd0f05cc9821f74a87e2ec4e1229a09fa389
--- /dev/null
+++ b/tokenizer.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:93d5fd6d2f8cf1172ac86cf982e2b88fa6732366b44dc1a32349379a54a6a044
+size 11423346
diff --git a/tokenizer_config.json b/tokenizer_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..c4657dc1d512391538fdb16dd905945ad376062f
--- /dev/null
+++ b/tokenizer_config.json
@@ -0,0 +1,242 @@
+{
+ "add_bos_token": false,
+ "add_eos_token": false,
+ "add_prefix_space": null,
+ "added_tokens_decoder": {
+ "151643": {
+ "content": "<|begin▁of▁sentence|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151644": {
+ "content": "<|im_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151645": {
+ "content": "<|end▁of▁sentence|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151646": {
+ "content": "<|object_ref_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151647": {
+ "content": "<|object_ref_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151648": {
+ "content": "<|box_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151649": {
+ "content": "<|box_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151650": {
+ "content": "<|quad_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151651": {
+ "content": "<|quad_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151652": {
+ "content": "<|vision_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151653": {
+ "content": "<|vision_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151654": {
+ "content": "<|vision_pad|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151655": {
+ "content": "<|image_pad|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151656": {
+ "content": "<|video_pad|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151657": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151658": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151659": {
+ "content": "<|fim_prefix|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151660": {
+ "content": "<|fim_middle|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151661": {
+ "content": "<|fim_suffix|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151662": {
+ "content": "<|fim_pad|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151663": {
+ "content": "<|repo_name|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151664": {
+ "content": "<|file_sep|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151665": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151666": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151667": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151668": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151669": {
+ "content": "<|User|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151670": {
+ "content": "<|Assistant|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ }
+ },
+ "bos_token": "<|begin▁of▁sentence|>",
+ "clean_up_tokenization_spaces": false,
+ "eos_token": "<|end▁of▁sentence|>",
+ "extra_special_tokens": {},
+ "legacy": true,
+ "model_max_length": 131072,
+ "pad_token": "<|end▁of▁sentence|>",
+ "sp_model_kwargs": {},
+ "tokenizer_class": "LlamaTokenizerFast",
+ "unk_token": null,
+ "use_default_system_prompt": false
+}
diff --git a/training_args.bin b/training_args.bin
new file mode 100644
index 0000000000000000000000000000000000000000..a275320c73f948be3ab0a894560d261ebe08c98a
--- /dev/null
+++ b/training_args.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:840c9e203e24a719dec7c5250c5553e38d741952e631bfb2dfe7bddec9d6e991
+size 7633
diff --git a/zero_to_fp32.py b/zero_to_fp32.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e759146cadd92ddfefab3680146c2bd6a2b5c04
--- /dev/null
+++ b/zero_to_fp32.py
@@ -0,0 +1,760 @@
+#!/usr/bin/env python
+
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+# This script extracts fp32 consolidated weights from ZeRO 1, 2 and 3 DeepSpeed checkpoints. It gets
+# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
+# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
+# application.
+#
+# example:
+# python zero_to_fp32.py . output_dir/
+# or
+# python zero_to_fp32.py . output_dir/ --safe_serialization
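+# or, using the checkpoint tag saved in this repo with a custom shard size (illustrative):
+# python zero_to_fp32.py . output_dir/ -t global_step50 --max_shard_size 2GB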
+
+import argparse
+import torch
+import glob
+import math
+import os
+import re
+import gc
+import json
+import numpy as np
+from tqdm import tqdm
+from collections import OrderedDict
+from dataclasses import dataclass
+
+# while this script doesn't use deepspeed to recover data, the checkpoints are pickled with
+# DeepSpeed data structures, so deepspeed has to be available in the current python environment.
+from deepspeed.utils import logger
+from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
+
+
+@dataclass
+class zero_model_state:
+ buffers: dict
+ param_shapes: dict
+ shared_params: list
+ ds_version: int
+ frozen_param_shapes: dict
+ frozen_param_fragments: dict
+
+
+debug = 0
+
+# load to cpu
+device = torch.device('cpu')
+
+
+def atoi(text):
+ return int(text) if text.isdigit() else text
+
+
+def natural_keys(text):
+ '''
+ alist.sort(key=natural_keys) sorts in human order
+ http://nedbatchelder.com/blog/200712/human_sorting.html
+ (See Toothy's implementation in the comments)
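+ e.g. sorted(['rank_10', 'rank_2'], key=natural_keys) -> ['rank_2', 'rank_10']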
+ '''
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
+
+
+def get_model_state_file(checkpoint_dir, zero_stage):
+ if not os.path.isdir(checkpoint_dir):
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
+
+ # there should be only one file
+ if zero_stage <= 2:
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
+ elif zero_stage == 3:
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
+
+ if not os.path.exists(file):
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
+
+ return file
+
+
+def get_checkpoint_files(checkpoint_dir, glob_pattern):
+ # XXX: need to test that this simple glob rule works for multi-node setup too
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
+
+ if len(ckpt_files) == 0:
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
+
+ return ckpt_files
+
+
+def get_optim_files(checkpoint_dir):
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
+
+
+def get_model_state_files(checkpoint_dir):
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
+
+
+def parse_model_states(files):
+ zero_model_states = []
+ for file in files:
+ state_dict = torch.load(file, map_location=device, weights_only=False)
+
+ if BUFFER_NAMES not in state_dict:
+ raise ValueError(f"{file} is not a model state checkpoint")
+ buffer_names = state_dict[BUFFER_NAMES]
+ if debug:
+ print("Found buffers:", buffer_names)
+
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
+ param_shapes = state_dict[PARAM_SHAPES]
+
+ # collect parameters that are included in param_shapes
+ param_names = []
+ for s in param_shapes:
+ for name in s.keys():
+ param_names.append(name)
+
+ # update with frozen parameters
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
+ if frozen_param_shapes is not None:
+ if debug:
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
+ param_names += list(frozen_param_shapes.keys())
+
+ # handle shared params
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
+
+ ds_version = state_dict.get(DS_VERSION, None)
+
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
+
+ z_model_state = zero_model_state(buffers=buffers,
+ param_shapes=param_shapes,
+ shared_params=shared_params,
+ ds_version=ds_version,
+ frozen_param_shapes=frozen_param_shapes,
+ frozen_param_fragments=frozen_param_fragments)
+ zero_model_states.append(z_model_state)
+
+ return zero_model_states
+
+
+def parse_optim_states(files, ds_checkpoint_dir):
+ total_files = len(files)
+ state_dicts = []
+ for f in tqdm(files, desc='Loading checkpoint shards'):
+ state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
+ # immediately discard the two potentially huge optimizer state tensors, as we only care about the fp32 master weights,
+ # and also handle the case where it was already removed by another helper script
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
+ state_dicts.append(state_dict)
+
+ if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
+
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
+ # use the max of the partition_count to get the dp world_size.
+
+ if type(world_size) is list:
+ world_size = max(world_size)
+
+ if world_size != total_files:
+ raise ValueError(
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
+ )
+
+ # the groups are named differently in each stage
+ if zero_stage <= 2:
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
+ elif zero_stage == 3:
+ fp32_groups_key = FP32_FLAT_GROUPS
+ else:
+ raise ValueError(f"unknown zero stage {zero_stage}")
+
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
+ return zero_stage, world_size, fp32_flat_groups
+
+
+def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
+ """
+ Returns fp32 state_dict reconstructed from ds checkpoint
+
+ Args:
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
+
+ """
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
+
+ optim_files = get_optim_files(ds_checkpoint_dir)
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
+
+ model_files = get_model_state_files(ds_checkpoint_dir)
+
+ zero_model_states = parse_model_states(model_files)
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
+
+ if zero_stage <= 2:
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+ exclude_frozen_parameters)
+ elif zero_stage == 3:
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+ exclude_frozen_parameters)
+
+
+def _zero2_merge_frozen_params(state_dict, zero_model_states):
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+ return
+
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
+
+ if debug:
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+
+ wanted_params = len(frozen_param_shapes)
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
+ print(f'Frozen params: Have {avail_numel} numels to process.')
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+
+ total_params = 0
+ total_numel = 0
+ for name, shape in frozen_param_shapes.items():
+ total_params += 1
+ unpartitioned_numel = shape.numel()
+ total_numel += unpartitioned_numel
+
+ state_dict[name] = frozen_param_fragments[name]
+
+ if debug:
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _has_callable(obj, fn):
+ attr = getattr(obj, fn, None)
+ return callable(attr)
+
+
+def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
+ param_shapes = zero_model_states[0].param_shapes
+
+ # Reconstruction protocol:
+ #
+ # XXX: document this
+
+ if debug:
+ for i in range(world_size):
+ for j in range(len(fp32_flat_groups[0])):
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
+
+ # XXX: memory usage doubles here (zero2)
+ num_param_groups = len(fp32_flat_groups[0])
+ merged_single_partition_of_fp32_groups = []
+ for i in range(num_param_groups):
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
+ avail_numel = sum(
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
+
+ if debug:
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
+ # not asserting if there is a mismatch due to possible padding
+ print(f"Have {avail_numel} numels to process.")
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
+
+ # params
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+ # an out-of-core computing solution
+ total_numel = 0
+ total_params = 0
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
+ offset = 0
+ avail_numel = full_single_fp32_vector.numel()
+ for name, shape in shapes.items():
+
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
+ total_numel += unpartitioned_numel
+ total_params += 1
+
+ if debug:
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
+ offset += unpartitioned_numel
+
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
+ # live optimizer object, so we are checking that the numbers are within the right range
+ align_to = 2 * world_size
+
+ def zero2_align(x):
+ return align_to * math.ceil(x / align_to)
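+ # e.g. with world_size=8, align_to is 16 and zero2_align(100) == 112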
+
+ if debug:
+ print(f"original offset={offset}, avail_numel={avail_numel}")
+
+ offset = zero2_align(offset)
+ avail_numel = zero2_align(avail_numel)
+
+ if debug:
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
+
+ # Sanity check
+ if offset != avail_numel:
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+ exclude_frozen_parameters):
+ state_dict = OrderedDict()
+
+ # buffers
+ buffers = zero_model_states[0].buffers
+ state_dict.update(buffers)
+ if debug:
+ print(f"added {len(buffers)} buffers")
+
+ if not exclude_frozen_parameters:
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
+
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+
+ # recover shared parameters
+ for pair in zero_model_states[0].shared_params:
+ if pair[1] in state_dict:
+ state_dict[pair[0]] = state_dict[pair[1]]
+
+ return state_dict
+
+
+def zero3_partitioned_param_info(unpartitioned_numel, world_size):
+ remainder = unpartitioned_numel % world_size
+ padding_numel = (world_size - remainder) if remainder else 0
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
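+ # worked example (illustrative): unpartitioned_numel=10, world_size=4 ->
+ # partitioned_numel=ceil(10/4)=3, padding_numel=4-(10%4)=2, since
+ # 4 ranks * 3 elements = 12 = 10 real + 2 padding elements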
+ return partitioned_numel, padding_numel
+
+
+def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+ return
+
+ if debug:
+ for i in range(world_size):
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
+ wanted_params = len(frozen_param_shapes)
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
+ print(f'Frozen params: Have {avail_numel} numels to process.')
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+
+ total_params = 0
+ total_numel = 0
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
+ total_params += 1
+ unpartitioned_numel = shape.numel()
+ total_numel += unpartitioned_numel
+
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
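+ # concatenate the per-rank fragments, trim the ZeRO padding tail, and restore the original shape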
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
+
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+ if debug:
+ print(
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+ )
+
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
+
+
+class GatheredTensor:
+ """
+ A pseudo tensor that collects partitioned weights.
+ It is more memory efficient when there are multiple groups.
+ """
+
+ def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
+ self.flat_groups = flat_groups
+ self.flat_groups_offset = flat_groups_offset
+ self.offset = offset
+ self.partitioned_numel = partitioned_numel
+ self.shape = shape
+ self.dtype = self.flat_groups[0][0].dtype
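+ # note: nothing is gathered here; the full parameter is only materialized
+ # when .contiguous() is called, which keeps peak memory low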
+
+ def contiguous(self):
+ """
+ Merge partitioned weights from flat_groups into a single tensor.
+ """
+ end_idx = self.offset + self.partitioned_numel
+ world_size = len(self.flat_groups)
+ pad_flat_param_chunks = []
+
+ for rank_i in range(world_size):
+ # for each rank, we need to collect weights from related group/groups
+ flat_groups_at_rank_i = self.flat_groups[rank_i]
+ start_group_id = None
+ end_group_id = None
+ for group_id in range(len(self.flat_groups_offset)):
+ if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
+ start_group_id = group_id
+ if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
+ end_group_id = group_id
+ break
+ # collect weights from related group/groups
+ for group_id in range(start_group_id, end_group_id + 1):
+ flat_tensor = flat_groups_at_rank_i[group_id]
+ start_offset = self.offset - self.flat_groups_offset[group_id]
+ end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
+ pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
+
+ # collect weights from all ranks
+ pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
+ param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
+ return param
+
+
+def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
+ param_shapes = zero_model_states[0].param_shapes
+ avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
+
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at the boundary of each
+ # param, re-consolidating each param while dealing with padding, if any
+
+ # merge list of dicts, preserving order
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
+
+ if debug:
+ for i in range(world_size):
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
+
+ wanted_params = len(param_shapes)
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
+ # not asserting if there is a mismatch due to possible padding
+ avail_numel = sum(flat_group.numel() for flat_group in fp32_flat_groups[0]) * world_size
+ print(f"Trainable params: Have {avail_numel} numels to process.")
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
+
+ # params
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+ # an out-of-core computing solution
+ offset = 0
+ total_numel = 0
+ total_params = 0
+ flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
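+ # flat_groups_offset marks where each param group starts within a rank's
+ # concatenated flat buffer, e.g. group numels [100, 50] -> [0, 100, 150]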
+ for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
+ unpartitioned_numel = shape.numel()
+ total_numel += unpartitioned_numel
+ total_params += 1
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+ if debug:
+ print(
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+ )
+
+ # memory efficient tensor
+ tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
+ state_dict[name] = tensor
+ offset += partitioned_numel
+
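+ # each rank advanced `offset` by its own partition sizes, so the total
+ # consumed across all ranks is offset * world_size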
+ offset *= world_size
+
+ # Sanity check
+ if offset != avail_numel:
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+ exclude_frozen_parameters):
+ state_dict = OrderedDict()
+
+ # buffers
+ buffers = zero_model_states[0].buffers
+ state_dict.update(buffers)
+ if debug:
+ print(f"added {len(buffers)} buffers")
+
+ if not exclude_frozen_parameters:
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
+
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+
+ # recover shared parameters
+ for pair in zero_model_states[0].shared_params:
+ if pair[1] in state_dict:
+ state_dict[pair[0]] = state_dict[pair[1]]
+
+ return state_dict
+
+
+def to_torch_tensor(state_dict, return_empty_tensor=False):
+ """
+ Convert state_dict of GatheredTensor to torch tensor
+ """
+ torch_state_dict = {}
+ converted_tensors = {}
+ for name, tensor in state_dict.items():
+ tensor_id = id(tensor)
+ if tensor_id in converted_tensors: # shared tensors
+ shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
+ torch_state_dict[name] = shared_tensor
+ else:
+ converted_tensors[tensor_id] = name
+ if return_empty_tensor:
+ torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
+ else:
+ torch_state_dict[name] = tensor.contiguous()
+ return torch_state_dict
+
+
+def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
+ tag=None,
+ exclude_frozen_parameters=False,
+ lazy_mode=False):
+ """
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
+ via a model hub.
+
+ Args:
+ - ``checkpoint_dir``: path to the desired checkpoint folder
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to read the tag from the 'latest' file, e.g., ``global_step14``
+ - ``exclude_frozen_parameters``: exclude frozen parameters
+ - ``lazy_mode``: get the state_dict in lazy mode. It returns a dict of pseudo tensors instead of torch tensors, which is more memory efficient.
+ Convert a pseudo tensor to a torch tensor with ``.contiguous()``
+
+ Returns:
+ - pytorch ``state_dict``
+
+ A typical usage might be ::
+
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+ # do the training and checkpoint saving
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
+ model = model.cpu() # move to cpu
+ model.load_state_dict(state_dict)
+ # submit to model hub or save the model to share with others
+
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
+ application. i.e. you will need to re-initialize the deepspeed engine, since
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
+
+ Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
+ You may need to use the offline approach via the ``zero_to_fp32.py`` script that is saved with
+ the checkpoint. Or you can load the state_dict in lazy mode ::
+
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
+ for name, lazy_tensor in state_dict.items():
+ tensor = lazy_tensor.contiguous() # to cpu
+ print(name, tensor)
+ # del tensor to release memory if it is no longer in use
+ """
+ if tag is None:
+ latest_path = os.path.join(checkpoint_dir, 'latest')
+ if os.path.isfile(latest_path):
+ with open(latest_path, 'r') as fd:
+ tag = fd.read().strip()
+ else:
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
+
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
+
+ if not os.path.isdir(ds_checkpoint_dir):
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
+
+ state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
+ if lazy_mode:
+ return state_dict
+ else:
+ return to_torch_tensor(state_dict)
+
+
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
+ output_dir,
+ max_shard_size="5GB",
+ safe_serialization=False,
+ tag=None,
+ exclude_frozen_parameters=False):
+ """
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
+
+ Args:
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+ - ``output_dir``: directory for the pytorch fp32 state_dict output files
+ - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
+ - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to read the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+ - ``exclude_frozen_parameters``: exclude frozen parameters
+ """
+
+ # Dependency pre-check
+ if safe_serialization:
+ try:
+ from safetensors.torch import save_file
+ except ImportError:
+ print('If you want to use `safe_serialization`, please `pip install safetensors`')
+ raise
+ if max_shard_size is not None:
+ try:
+ from huggingface_hub import split_torch_state_dict_into_shards
+ except ImportError:
+ print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
+ raise
+
+ # Convert zero checkpoint to state_dict
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
+ tag,
+ exclude_frozen_parameters,
+ lazy_mode=True)
+
+ # Shard the model if it is too big.
+ weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
+ if max_shard_size is not None:
+ filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
+ # a memory-efficient approach to sharding
+ empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
+ state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
+ filename_pattern=filename_pattern,
+ max_shard_size=max_shard_size)
+ else:
+ from collections import namedtuple
+ StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
+ state_dict_split = StateDictSplit(is_sharded=False,
+ filename_to_tensors={weights_name: list(state_dict.keys())})
+
+ # Save the model shard by shard
+ os.makedirs(output_dir, exist_ok=True)
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
+ for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
+ shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
+ shard_state_dict = to_torch_tensor(shard_state_dict)
+ output_path = os.path.join(output_dir, shard_file)
+ if safe_serialization:
+ save_file(shard_state_dict, output_path, metadata={"format": "pt"})
+ else:
+ torch.save(shard_state_dict, output_path)
+ # release the memory of the current shard
+ for tensor_name in list(shard_state_dict.keys()):
+ del state_dict[tensor_name]
+ del shard_state_dict[tensor_name]
+ del shard_state_dict
+ gc.collect()
+
+ # Save index if sharded
+ if state_dict_split.is_sharded:
+ index = {
+ "metadata": state_dict_split.metadata,
+ "weight_map": state_dict_split.tensor_to_filename,
+ }
+ save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
+ save_index_file = os.path.join(output_dir, save_index_file)
+ with open(save_index_file, "w", encoding="utf-8") as f:
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+ f.write(content)
+
+
+def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
+ """
+ 1. Move the provided model to CPU
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
+ 3. Load it into the provided model
+
+ Args:
+ - ``model``: the model object to update
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to read the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+
+ Returns:
+ - ``model``: the modified model
+
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
+ have enough, use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
+ conveniently placed for you in the checkpoint folder.
+
+ A typical usage might be ::
+
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
+ # submit to model hub or save the model to share with others
+
+ Note that once this has run, the ``model`` will no longer be usable in the deepspeed context
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+ """
+ logger.info(f"Extracting fp32 weights")
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+
+ logger.info(f"Overwriting model with fp32 weights")
+ model = model.cpu()
+ model.load_state_dict(state_dict, strict=False)
+
+ return model
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("checkpoint_dir",
+ type=str,
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
+ parser.add_argument("output_dir",
+ type=str,
+ help="directory to the pytorch fp32 state_dict output files"
+ "(e.g. path/checkpoint-12-output/)")
+ parser.add_argument(
+ "--max_shard_size",
+ type=str,
+ default="5GB",
+ help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
+ "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
+ "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
+ "without CPU OOM issues.")
+ parser.add_argument(
+ "--safe_serialization",
+ default=False,
+ action='store_true',
+ help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
+ parser.add_argument("-t",
+ "--tag",
+ type=str,
+ default=None,
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
+ parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
+ args = parser.parse_args()
+
+ debug = args.debug
+
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
+ args.output_dir,
+ max_shard_size=args.max_shard_size,
+ safe_serialization=args.safe_serialization,
+ tag=args.tag,
+ exclude_frozen_parameters=args.exclude_frozen_parameters)