diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..a551f29e64e73ebc6e885647c943f2c7b92784a1 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:803edb191c60e7b337a4680b71ac2ec51b96efa94c9a8df64f386978c7c45449 +size 217242608 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..861b54053b2d52233fbc492005c13898af126e5d --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9240b138accadb839698cb3c2d12e37473a192919601cf698a20ed89d61aabb1 +size 217242672 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..30b9250dc0c1deda208bbb8f9d5cb008f2893a77 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5ca6623a05ca3b43e6f89e9b97c7f07b63abb89bafe2de8c59a2284be4a0d15b +size 217243248 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..d98c85e4bd0d679d6b3245f81517b41ba490e524 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ab58379d16aa40ca7eaaca2bcc7f5b78dc2ee25173bb907949a3a5f7781fbc9e +size 217243312 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..4dbf9b9abe1e1654bd2dfc0be3e11487d72693fb --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:34788b36148c5ae8937cfb235ee5444447efb72886042b29b55fc32dfb93d96e +size 217243376 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_5_mp_rank_00_optim_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_5_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..bd2720c685c2a9f4c759c05c1020596c182c8ec8 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_5_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c4823f0f2a29fb51a0b9ba31b65c6d1d9af9a779de9c9ef34864bb5e9cef68e +size 217243312 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_6_mp_rank_00_optim_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_6_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..9ab7a072bc10dfe12fb0597f8e0284b66e29ac38 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_6_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee88e69a3043452c0aafdc22bf440897dc0e1c6821db84dc795de75887d70d81 +size 217243376 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_7_mp_rank_00_optim_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_7_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..67821c8c2bcd2d550ae2b2e58e593d5bc1f536eb --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/bf16_zero_pp_rank_7_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4e3b35c740395e296accd7e3bfa51a836b43988953703a24fddba545ecfec1f +size 217243184 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_01-model_00-model_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_01-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..8c9ce198ebad2a82a4445b7169faa21ecfaa7b86 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_01-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c481dfda78247ba4a17e96183fc5569517a3679277735fb4cbfaaeed43bcb05 +size 77268367 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_02-model_00-model_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_02-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..1afba3385dd7cd0150d3719ddab2a0b2043722e7 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_02-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b40121260d7debe74fd2fe168bd4c0b6df4ef48b38df66a05b60ec640f984820 +size 14161774 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_03-model_00-model_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_03-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..86d187d23e5c2f1567b7cffdf9d015b4f06528be --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_03-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a47d2b6b3c6601af9238b1356c2ee744d13e6d0e2db5736bbdefd56cf8be7298 +size 14161774 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_04-model_00-model_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_04-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..2405b2e17ee55c7654646aaa8139f69927f268d8 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_04-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44e06c69d22c8c76eccc4498b93d5c02130f30da08e0d824ded6c9bd347c0718 +size 14161774 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_05-model_00-model_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_05-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..47b7fd2ec40ad248809ad6053b277a4b68434564 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_05-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c6ac6db57e80164e1aa5ca31f5cf3c3cf10721051a926d6336336f4585487c2c +size 14161774 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_06-model_00-model_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_06-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..7fb776c58107662b3fd301f487f57fc686c8a959 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_06-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:921f33567fd55363333af15b435aa78096a7a970275a7c15892d170726993d6c +size 14161774 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_07-model_00-model_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_07-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..7908dca05cfb38d07a1be845576ec3ade71e0657 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_07-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:28006fb55f7cca41e0fddc4db86f0cc17a260b87361de116c075c09a4ff56d44 +size 14161774 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_08-model_00-model_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_08-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..2e04f9118f60a078b44ef937aaeab01b61a62b2e --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_08-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0b13e0a6ec6127c79b264bf350551bad92eae71d9e857fcb2164addc194d96fd +size 14161774 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_09-model_00-model_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_09-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..7827dce6a7f5867489efb24c8b01922e235b3b61 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_09-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1803f75888421837cf4eb4d46215b9d51db05e340ebef3da2bf5bed9436096e2 +size 14161774 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_10-model_00-model_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_10-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..23719c9cef5f0befcb5302bb0412cb99b6a69014 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_10-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e68c1b3c90de28d31e4542f1ecc49c67222ecc014750846908815f86f4d34d53 +size 14161774 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_11-model_00-model_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_11-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..7c45b90bcb6c458d95fffbbff853500d6d8ab1b9 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_11-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4995e8031892a0883b65d511544f6b878d44376fd5436c0dee414bf92cb22373 +size 14161774 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_12-model_00-model_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_12-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..365629cf335207e45e1dfdc33acb61952765257c --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_12-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d7a6a8ff8af668b5b43e21158fb7a9a5e76201a73920685cb2b71762e0de6708 +size 14161774 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_13-model_00-model_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_13-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..1c57d3eda0129280ec9976c8ffc0ad77af6cbd5d --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_13-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1fd80a3781714f56e9db3b1f9b63a67a14fef88e954cb87d171e381c18648b63 +size 14161774 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_14-model_00-model_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_14-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..a930066eaf41099e25a13f855778530ca9ad4ee8 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_14-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6afd48c6ce373261247cb2d18f47166e27a1056ad01bc8485cf35510de6d1f96 +size 14161774 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_15-model_00-model_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_15-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..4146e7d2e0a3a69ed40d9cae0b0a2a6d0eb97c38 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_15-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1ee5bed9728287b1ac430d94cea80ed7670be9f6be94bfe4425cb53c6f9ce8b +size 14161774 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_16-model_00-model_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_16-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..e3f7165ef2481fd0274c1051eb68e7554ede5178 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_16-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57cc7160f0e2f923696229a83baed9b2a1854eef4875d8f97315a07811fd4d96 +size 14161774 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_17-model_00-model_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_17-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..9de1dd2c225f2de44576544b0f99d6b1c4e58217 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/layer_17-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:47c136bf944be5508f97a1572203b471f82945007760c7c97f747f398d3e9ac5 +size 2959 diff --git a/100m_new/checkpoints_arm_best_on_100m/global_step9532/mp_rank_00_model_states.pt b/100m_new/checkpoints_arm_best_on_100m/global_step9532/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..8864c13015414e25d5ad7532abe650b6cc22ca8a --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/global_step9532/mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6e98e2ee787563273820553e7a1fd8301d13fbbc29070623829b5ded885a8876 +size 8422812 diff --git a/100m_new/checkpoints_arm_best_on_100m/latest b/100m_new/checkpoints_arm_best_on_100m/latest new file mode 100644 index 0000000000000000000000000000000000000000..ed78d7d3e7b2b52266ff43eb52825690eef8d0c2 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/latest @@ -0,0 +1 @@ +global_step9532 \ No newline at end of file diff --git a/100m_new/checkpoints_arm_best_on_100m/latest_checkpointed_iteration.txt b/100m_new/checkpoints_arm_best_on_100m/latest_checkpointed_iteration.txt new file mode 100644 index 0000000000000000000000000000000000000000..6bee3d242c126c9b487d7d5786f72a8f35165117 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/latest_checkpointed_iteration.txt @@ -0,0 +1 @@ +9532 \ No newline at end of file diff --git a/100m_new/checkpoints_arm_best_on_100m/zero_to_fp32.py b/100m_new/checkpoints_arm_best_on_100m/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..0e759146cadd92ddfefab3680146c2bd6a2b5c04 --- /dev/null +++ b/100m_new/checkpoints_arm_best_on_100m/zero_to_fp32.py @@ -0,0 +1,760 @@ +#!/usr/bin/env python + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: +# python zero_to_fp32.py . output_dir/ +# or +# python zero_to_fp32.py . output_dir/ --safe_serialization + +import argparse +import torch +import glob +import math +import os +import re +import gc +import json +import numpy as np +from tqdm import tqdm +from collections import OrderedDict +from dataclasses import dataclass + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. +from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + + +@dataclass +class zero_model_state: + buffers: dict() + param_shapes: dict() + shared_params: list + ds_version: int + frozen_param_shapes: dict() + frozen_param_fragments: dict() + + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage <= 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_checkpoint_files(checkpoint_dir, glob_pattern): + # XXX: need to test that this simple glob rule works for multi-node setup too + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) + + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") + + return ckpt_files + + +def get_optim_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") + + +def get_model_state_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") + + +def parse_model_states(files): + zero_model_states = [] + for file in files: + state_dict = torch.load(file, map_location=device, weights_only=False) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} + param_shapes = state_dict[PARAM_SHAPES] + + # collect parameters that are included in param_shapes + param_names = [] + for s in param_shapes: + for name in s.keys(): + param_names.append(name) + + # update with frozen parameters + frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) + if frozen_param_shapes is not None: + if debug: + print(f"Found frozen_param_shapes: {frozen_param_shapes}") + param_names += list(frozen_param_shapes.keys()) + + # handle shared params + shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] + + ds_version = state_dict.get(DS_VERSION, None) + + frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) + + z_model_state = zero_model_state(buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments) + zero_model_states.append(z_model_state) + + return zero_model_states + + +def parse_optim_states(files, ds_checkpoint_dir): + total_files = len(files) + state_dicts = [] + for f in tqdm(files, desc='Loading checkpoint shards'): + state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False) + # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights + # and also handle the case where it was already removed by another helper script + state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) + state_dicts.append(state_dict) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." + ) + + # the groups are named differently in each stage + if zero_stage <= 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_files = get_model_state_files(ds_checkpoint_dir) + + zero_model_states = parse_model_states(model_files) + print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') + + if zero_stage <= 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + + +def _zero2_merge_frozen_params(state_dict, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + frozen_param_fragments = zero_model_states[0].frozen_param_fragments + + if debug: + num_elem = sum(s.numel() for s in frozen_param_shapes.values()) + print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + state_dict[name] = frozen_param_fragments[name] + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _has_callable(obj, fn): + attr = getattr(obj, fn, None) + return callable(attr) + + +def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum( + [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero2_merge_frozen_params(state_dict, zero_model_states) + + _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + if debug: + for i in range(world_size): + num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) + print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in zero_model_states[0].frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) + state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +class GatheredTensor: + """ + A pseudo tensor that collects partitioned weights. + It is more memory efficient when there are multiple groups. + """ + + def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape): + self.flat_groups = flat_groups + self.flat_groups_offset = flat_groups_offset + self.offset = offset + self.partitioned_numel = partitioned_numel + self.shape = shape + self.dtype = self.flat_groups[0][0].dtype + + def contiguous(self): + """ + Merge partitioned weights from flat_groups into a single tensor. + """ + end_idx = self.offset + self.partitioned_numel + world_size = len(self.flat_groups) + pad_flat_param_chunks = [] + + for rank_i in range(world_size): + # for each rank, we need to collect weights from related group/groups + flat_groups_at_rank_i = self.flat_groups[rank_i] + start_group_id = None + end_group_id = None + for group_id in range(len(self.flat_groups_offset)): + if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]: + start_group_id = group_id + if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]: + end_group_id = group_id + break + # collect weights from related group/groups + for group_id in range(start_group_id, end_group_id + 1): + flat_tensor = flat_groups_at_rank_i[group_id] + start_offset = self.offset - self.flat_groups_offset[group_id] + end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id] + pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset]) + + # collect weights from all ranks + pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0) + param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous() + return param + + +def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size + + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + avail_numel = fp32_flat_groups[0].numel() * world_size + print(f"Trainable params: Have {avail_numel} numels to process.") + print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]])) + for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'): + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # memory efficient tensor + tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape) + state_dict[name] = tensor + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + + _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def to_torch_tensor(state_dict, return_empty_tensor=False): + """ + Convert state_dict of GatheredTensor to torch tensor + """ + torch_state_dict = {} + converted_tensors = {} + for name, tensor in state_dict.items(): + tensor_id = id(tensor) + if tensor_id in converted_tensors: # shared tensors + shared_tensor = torch_state_dict[converted_tensors[tensor_id]] + torch_state_dict[name] = shared_tensor + else: + converted_tensors[tensor_id] = name + if return_empty_tensor: + torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype) + else: + torch_state_dict[name] = tensor.contiguous() + return torch_state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, + tag=None, + exclude_frozen_parameters=False, + lazy_mode=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pesduo tensor instead of torch tensor, which is more memory efficient. + Convert the pesduo tensor to torch tensor by ``.contiguous()`` + + Returns: + - pytorch ``state_dict`` + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + Note: the above usage may not work if your application doesn't have sufficient free CPU memory. + You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. Or you can load state_dict in lazy mode :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu + for name, lazy_tensor in state_dict.item(): + tensor = lazy_tensor.contiguous() # to cpu + print(name, tensor) + # del tensor to release memory if it no longer in use + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) + if lazy_mode: + return state_dict + else: + return to_torch_tensor(state_dict) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, + output_dir, + max_shard_size="5GB", + safe_serialization=False, + tag=None, + exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_dir``: directory to the pytorch fp32 state_dict output files + - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB + - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + """ + + # Dependency pre-check + if safe_serialization: + try: + from safetensors.torch import save_file + except ImportError: + print('If you want to use `safe_serialization`, please `pip install safetensors`') + raise + if max_shard_size is not None: + try: + from huggingface_hub import split_torch_state_dict_into_shards + except ImportError: + print('If you want to use `max_shard_size`, please `pip install huggingface_hub`') + raise + + # Convert zero checkpoint to state_dict + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, + tag, + exclude_frozen_parameters, + lazy_mode=True) + + # Shard the model if it is too big. + weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin" + if max_shard_size is not None: + filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors") + # an memory-efficient approach for sharding + empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True) + state_dict_split = split_torch_state_dict_into_shards(empty_state_dict, + filename_pattern=filename_pattern, + max_shard_size=max_shard_size) + else: + from collections import namedtuple + StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"]) + state_dict_split = StateDictSplit(is_sharded=False, + filename_to_tensors={weights_name: list(state_dict.keys())}) + + # Save the model by shard + os.makedirs(output_dir, exist_ok=True) + filename_to_tensors = state_dict_split.filename_to_tensors.items() + for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"): + shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors} + shard_state_dict = to_torch_tensor(shard_state_dict) + output_path = os.path.join(output_dir, shard_file) + if safe_serialization: + save_file(shard_state_dict, output_path, metadata={"format": "pt"}) + else: + torch.save(shard_state_dict, output_path) + # release the memory of current shard + for tensor_name in list(shard_state_dict.keys()): + del state_dict[tensor_name] + del shard_state_dict[tensor_name] + del shard_state_dict + gc.collect() + + # Save index if sharded + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json" + save_index_file = os.path.join(output_dir, save_index_file) + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model`: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument("output_dir", + type=str, + help="directory to the pytorch fp32 state_dict output files" + "(e.g. path/checkpoint-12-output/)") + parser.add_argument( + "--max_shard_size", + type=str, + default="5GB", + help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size" + "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`" + "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances" + "without CPU OOM issues.") + parser.add_argument( + "--safe_serialization", + default=False, + action='store_true', + help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).") + parser.add_argument("-t", + "--tag", + type=str, + default=None, + help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") + parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, + args.output_dir, + max_shard_size=args.max_shard_size, + safe_serialization=args.safe_serialization, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters) diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..b7ceecaf472faaa0393788fbf3b5dc2b3b785d03 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4463f48710ad1261219877ff2428d9b0ec368d78205176ec3b01d9e31ce73a8c +size 625288752 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..483134736820cda0172df962a08a3be6d32c143f --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4438b7b988b7cb3f7742e8a691de99827fed921bdfcf742ceb1efc5154076ef7 +size 625289392 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..1357fa8ad7040f5b22e2007e9cca889e695910c8 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:713f32e14596d66f451f5ac45e18a9b0336e67c491f95f0a19490770cb6f633b +size 625289392 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..3e53d4d051a216921648fbecba33b6b239877606 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98727e08e1d805b1c58cbe056fcdaf7c71584548ea9a3560323874b10598fd81 +size 625289584 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..c14cc9be5bd8de5824b0bbf2dbbe40129e4489e6 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_4_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:24eef2072f8309a705c6d74f61d016f429005280faf4ec851df6423290a81504 +size 625289456 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_5_mp_rank_00_optim_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_5_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..1a1672756f5a4574341adba5d8626a2d3bedf00a --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_5_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0f995654e11893d198d07cc932b2376b9a35762d1c04cf410acc69dc1d831438 +size 625289392 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_6_mp_rank_00_optim_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_6_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..9b0d6f696327c271878a24a38ded2a4f70ec7766 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_6_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98c8c4d2b19f09faa9faf018506f7071935448d4af29cd1ec6b314ba4b026329 +size 625289584 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_7_mp_rank_00_optim_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_7_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..01ee002568761ffa4e4841cd66a8d52fe749e50c --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/bf16_zero_pp_rank_7_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:515fbc8e9fa0c655a7fd1f399fc100de4b78ad6011f7410d093e015404b21a30 +size 625289328 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_01-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_01-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..dad6f86f3e8aba19e10776b0398977b1b1e97740 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_01-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:905985b6e7df872457c0bd86b044c53b008051bae673d32f9336ceba3c908aa9 +size 128779663 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_02-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_02-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..8e0e140514d04615e298b0af41779eeed25083a7 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_02-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bab4e72325e01def152f437ecd3a6366c7967dd77d752a7568bb4b1a48f22cb0 +size 39165806 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_03-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_03-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..50d19e0787b1df45d41c9dde074a1307b8a42c0c --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_03-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:53a2cf20b17bb7761bbd10a4a9f52892ea0ec0635237c1c1a48121c61860624f +size 39165806 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_04-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_04-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..a9777c1a55b290a6ef2dd75f51966be6bd53b4a3 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_04-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a80da1f5f05e13be7057facb821ae6e014d89d7e206be7bba17e58d260200f22 +size 39165806 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_05-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_05-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..3d2be380feeb542e2a2b842b481f0f899c9dc3b9 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_05-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:906322a133d83faecfb8689782abc01d7bc84a9de7fdb4fdfb6b7ad720355adf +size 39165806 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_06-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_06-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..05e8eed61d409f7f260f4fb9558f94b5989f4628 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_06-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b83488c2c0f4784791bb80c25e20c780c0bc9f6eb94fc8873a829877cad977f9 +size 39165806 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_07-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_07-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..a33fd88b6f02f5d53f1aac00fdb227cbd64576a3 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_07-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e8e4a49734d466213d0b6ab0dff086211759c67a439a62ca657890476fc68620 +size 39165806 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_08-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_08-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..0a1a59e306301512b804d08224e45006348f32f4 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_08-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e312efec7fa49e721d9d1ddb10b89e949afa475985096d81787150eb557f142 +size 39165806 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_09-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_09-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..254af9eba9d21d683674d9dad507ed2ca5e1ac0e --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_09-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:23c1028aeca77558b89930ce711a246efae90a3f088fb782f2c45d1f3dcab3a6 +size 39165806 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_10-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_10-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..aaa35d46e6c8f04fe2d6835e6bd9af9f59b94560 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_10-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c6eeb39891c087a27c78ce94287774ce8011d0d77365fecda114eafb450ed6d +size 39165806 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_11-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_11-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..0348e17f8c5f10b83de67e20436cea82d2e5d779 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_11-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2925c02ae17a3e1892a22a40e0a2e13683db572b6c66294c8476ba2d29979e62 +size 39165806 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_12-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_12-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..2ef67443eb9cd1fd5544e4d07941c3a4de8736bf --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_12-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f0b30d2ed3b726794dd4f770224baa16a5eb55ea5b1046840432e160837d104a +size 39165806 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_13-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_13-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..2dc4228a0fc736891d4740138f19a677e66fb0de --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_13-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d947c1ebded660caa98ff539b64df85608e085a251abf500aa15c3caccbe4378 +size 39165806 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_14-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_14-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..e03ffbe307fae9363258d2eb9aafc6e3d0fed1bb --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_14-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d288778ddf28050e6eec6066be05b017de326bd96dc17b85bbd52eedd46de7e +size 39165806 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_15-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_15-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..9b15e33b88961bfa53cba4a06087533737b180a8 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_15-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:62f2b973c0f7f91ea33d2e26a3e5290e8c5bdf2521dcce6186a0519e47feb3d1 +size 39165806 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_16-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_16-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..014d226d9bc76bffe8857edff2ddb2c80bd4f995 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_16-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:45caa5aaae9d9de7d74fb6b4e81012126fd1b594fa7cd2571f92c9ee73c7f949 +size 39165806 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_17-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_17-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..a175032024d13cf9852e096d76114c66edeae25e --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_17-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cabef41b49aed5df37b8989d194bdca6e96fec9147ba1f5dd5991429e1be0422 +size 39165806 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_18-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_18-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..039e5b2fb1f289760040ccd9fb4f63c4a0b01a82 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_18-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ddf01d6f098c4b62afd672fbd1674a8f62421c2a5d6a525fe2437885e5828f72 +size 39165806 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_19-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_19-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..49e51e043c67b2b96964beacd01bf6155a40f5bf --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_19-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd33f6264f0aa4eafe655416e4bad78faff45e58c894a23e4195dc0890d71d15 +size 39165806 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_20-model_00-model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_20-model_00-model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..34f1bbda8c98cb078b51ce820b2209fd891ae273 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/layer_20-model_00-model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d448c4c1a7ff98eef34c921184fd27216eb540d5b50f306193dfa39f1aeaf1be +size 3983 diff --git a/100m_new/checkpoints_mdm_best_on_100m/global_step95322/mp_rank_00_model_states.pt b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..7fe675087f8cf56510e5ea415e7bf2dc6dd711b2 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/global_step95322/mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ce36ae6aa73f39ae835b381d3876e118d10c0eb50e1f62f5c06b104d0395f266 +size 5278364 diff --git a/100m_new/checkpoints_mdm_best_on_100m/latest b/100m_new/checkpoints_mdm_best_on_100m/latest new file mode 100644 index 0000000000000000000000000000000000000000..aed64165eb1c579a3b6460f759bb7c0657483524 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/latest @@ -0,0 +1 @@ +global_step95322 \ No newline at end of file diff --git a/100m_new/checkpoints_mdm_best_on_100m/latest_checkpointed_iteration.txt b/100m_new/checkpoints_mdm_best_on_100m/latest_checkpointed_iteration.txt new file mode 100644 index 0000000000000000000000000000000000000000..d4b9ec605963d1efa38aaa488d89c4ca98d70a2e --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/latest_checkpointed_iteration.txt @@ -0,0 +1 @@ +95322 \ No newline at end of file diff --git a/100m_new/checkpoints_mdm_best_on_100m/zero_to_fp32.py b/100m_new/checkpoints_mdm_best_on_100m/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..0e759146cadd92ddfefab3680146c2bd6a2b5c04 --- /dev/null +++ b/100m_new/checkpoints_mdm_best_on_100m/zero_to_fp32.py @@ -0,0 +1,760 @@ +#!/usr/bin/env python + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: +# python zero_to_fp32.py . output_dir/ +# or +# python zero_to_fp32.py . output_dir/ --safe_serialization + +import argparse +import torch +import glob +import math +import os +import re +import gc +import json +import numpy as np +from tqdm import tqdm +from collections import OrderedDict +from dataclasses import dataclass + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. +from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + + +@dataclass +class zero_model_state: + buffers: dict() + param_shapes: dict() + shared_params: list + ds_version: int + frozen_param_shapes: dict() + frozen_param_fragments: dict() + + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage <= 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_checkpoint_files(checkpoint_dir, glob_pattern): + # XXX: need to test that this simple glob rule works for multi-node setup too + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) + + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") + + return ckpt_files + + +def get_optim_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") + + +def get_model_state_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") + + +def parse_model_states(files): + zero_model_states = [] + for file in files: + state_dict = torch.load(file, map_location=device, weights_only=False) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} + param_shapes = state_dict[PARAM_SHAPES] + + # collect parameters that are included in param_shapes + param_names = [] + for s in param_shapes: + for name in s.keys(): + param_names.append(name) + + # update with frozen parameters + frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) + if frozen_param_shapes is not None: + if debug: + print(f"Found frozen_param_shapes: {frozen_param_shapes}") + param_names += list(frozen_param_shapes.keys()) + + # handle shared params + shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] + + ds_version = state_dict.get(DS_VERSION, None) + + frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) + + z_model_state = zero_model_state(buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments) + zero_model_states.append(z_model_state) + + return zero_model_states + + +def parse_optim_states(files, ds_checkpoint_dir): + total_files = len(files) + state_dicts = [] + for f in tqdm(files, desc='Loading checkpoint shards'): + state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False) + # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights + # and also handle the case where it was already removed by another helper script + state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) + state_dicts.append(state_dict) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." + ) + + # the groups are named differently in each stage + if zero_stage <= 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_files = get_model_state_files(ds_checkpoint_dir) + + zero_model_states = parse_model_states(model_files) + print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') + + if zero_stage <= 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + + +def _zero2_merge_frozen_params(state_dict, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + frozen_param_fragments = zero_model_states[0].frozen_param_fragments + + if debug: + num_elem = sum(s.numel() for s in frozen_param_shapes.values()) + print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + state_dict[name] = frozen_param_fragments[name] + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _has_callable(obj, fn): + attr = getattr(obj, fn, None) + return callable(attr) + + +def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum( + [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero2_merge_frozen_params(state_dict, zero_model_states) + + _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + if debug: + for i in range(world_size): + num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) + print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in zero_model_states[0].frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) + state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +class GatheredTensor: + """ + A pseudo tensor that collects partitioned weights. + It is more memory efficient when there are multiple groups. + """ + + def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape): + self.flat_groups = flat_groups + self.flat_groups_offset = flat_groups_offset + self.offset = offset + self.partitioned_numel = partitioned_numel + self.shape = shape + self.dtype = self.flat_groups[0][0].dtype + + def contiguous(self): + """ + Merge partitioned weights from flat_groups into a single tensor. + """ + end_idx = self.offset + self.partitioned_numel + world_size = len(self.flat_groups) + pad_flat_param_chunks = [] + + for rank_i in range(world_size): + # for each rank, we need to collect weights from related group/groups + flat_groups_at_rank_i = self.flat_groups[rank_i] + start_group_id = None + end_group_id = None + for group_id in range(len(self.flat_groups_offset)): + if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]: + start_group_id = group_id + if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]: + end_group_id = group_id + break + # collect weights from related group/groups + for group_id in range(start_group_id, end_group_id + 1): + flat_tensor = flat_groups_at_rank_i[group_id] + start_offset = self.offset - self.flat_groups_offset[group_id] + end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id] + pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset]) + + # collect weights from all ranks + pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0) + param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous() + return param + + +def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size + + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + avail_numel = fp32_flat_groups[0].numel() * world_size + print(f"Trainable params: Have {avail_numel} numels to process.") + print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]])) + for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'): + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # memory efficient tensor + tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape) + state_dict[name] = tensor + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + + _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def to_torch_tensor(state_dict, return_empty_tensor=False): + """ + Convert state_dict of GatheredTensor to torch tensor + """ + torch_state_dict = {} + converted_tensors = {} + for name, tensor in state_dict.items(): + tensor_id = id(tensor) + if tensor_id in converted_tensors: # shared tensors + shared_tensor = torch_state_dict[converted_tensors[tensor_id]] + torch_state_dict[name] = shared_tensor + else: + converted_tensors[tensor_id] = name + if return_empty_tensor: + torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype) + else: + torch_state_dict[name] = tensor.contiguous() + return torch_state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, + tag=None, + exclude_frozen_parameters=False, + lazy_mode=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pesduo tensor instead of torch tensor, which is more memory efficient. + Convert the pesduo tensor to torch tensor by ``.contiguous()`` + + Returns: + - pytorch ``state_dict`` + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + Note: the above usage may not work if your application doesn't have sufficient free CPU memory. + You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. Or you can load state_dict in lazy mode :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu + for name, lazy_tensor in state_dict.item(): + tensor = lazy_tensor.contiguous() # to cpu + print(name, tensor) + # del tensor to release memory if it no longer in use + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) + if lazy_mode: + return state_dict + else: + return to_torch_tensor(state_dict) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, + output_dir, + max_shard_size="5GB", + safe_serialization=False, + tag=None, + exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_dir``: directory to the pytorch fp32 state_dict output files + - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB + - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + """ + + # Dependency pre-check + if safe_serialization: + try: + from safetensors.torch import save_file + except ImportError: + print('If you want to use `safe_serialization`, please `pip install safetensors`') + raise + if max_shard_size is not None: + try: + from huggingface_hub import split_torch_state_dict_into_shards + except ImportError: + print('If you want to use `max_shard_size`, please `pip install huggingface_hub`') + raise + + # Convert zero checkpoint to state_dict + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, + tag, + exclude_frozen_parameters, + lazy_mode=True) + + # Shard the model if it is too big. + weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin" + if max_shard_size is not None: + filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors") + # an memory-efficient approach for sharding + empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True) + state_dict_split = split_torch_state_dict_into_shards(empty_state_dict, + filename_pattern=filename_pattern, + max_shard_size=max_shard_size) + else: + from collections import namedtuple + StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"]) + state_dict_split = StateDictSplit(is_sharded=False, + filename_to_tensors={weights_name: list(state_dict.keys())}) + + # Save the model by shard + os.makedirs(output_dir, exist_ok=True) + filename_to_tensors = state_dict_split.filename_to_tensors.items() + for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"): + shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors} + shard_state_dict = to_torch_tensor(shard_state_dict) + output_path = os.path.join(output_dir, shard_file) + if safe_serialization: + save_file(shard_state_dict, output_path, metadata={"format": "pt"}) + else: + torch.save(shard_state_dict, output_path) + # release the memory of current shard + for tensor_name in list(shard_state_dict.keys()): + del state_dict[tensor_name] + del shard_state_dict[tensor_name] + del shard_state_dict + gc.collect() + + # Save index if sharded + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json" + save_index_file = os.path.join(output_dir, save_index_file) + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model`: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument("output_dir", + type=str, + help="directory to the pytorch fp32 state_dict output files" + "(e.g. path/checkpoint-12-output/)") + parser.add_argument( + "--max_shard_size", + type=str, + default="5GB", + help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size" + "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`" + "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances" + "without CPU OOM issues.") + parser.add_argument( + "--safe_serialization", + default=False, + action='store_true', + help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).") + parser.add_argument("-t", + "--tag", + type=str, + default=None, + help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") + parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, + args.output_dir, + max_shard_size=args.max_shard_size, + safe_serialization=args.safe_serialization, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters)