import argparse
from pathlib import Path
from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
def parse_args(argv=None):
    """Parse the command-line options of the checkpoint conversion script.

    Args:
        argv: Optional list of argument strings. When ``None``, argparse
            reads the arguments from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``input`` (the checkpoint folder, as
        given on the command line) and ``output`` (the target file). When
        ``--output`` is omitted, ``output`` is ``<input>/converted.ckpt``.

    Raises:
        SystemExit: If ``--input`` is missing or the arguments are invalid
            (argparse prints a usage message first).
    """
    parser = argparse.ArgumentParser()
    # --input is mandatory: without it there is nothing to convert, and the
    # default output path below could not be derived from it either.
    parser.add_argument('--input', type=str, required=True, help='path to the desired checkpoint folder')
    parser.add_argument('--output', type=str, default=None, help='path to the pytorch fp32 state_dict output file')
    # parser.add_argument('--tag', type=str, help='checkpoint tag used as a unique identifier for checkpoint')
    args = parser.parse_args(argv)
    if args.output is None:
        # Default: write the converted file inside the checkpoint folder.
        args.output = Path(args.input) / 'converted.ckpt'
    return args


if __name__ == '__main__':
    # Read the paths from the command line and pass them to
    # convert_zero_checkpoint_to_fp32_state_dict.
    args = parse_args()
    convert_zero_checkpoint_to_fp32_state_dict(args.input, args.output)
# import argparse
# from pathlib import Path
# from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
# from torch.serialization import add_safe_globals
# from deepspeed.runtime.fp16.loss_scaler import LossScaler
# from deepspeed.runtime.zero.config import ZeroStageEnum
# from deepspeed.utils.tensor_fragment import fragment_address
# import torch
# if __name__ == '__main__':
# # Add DeepSpeed's LossScaler to the safe-globals allowlist
# add_safe_globals([LossScaler, ZeroStageEnum]) # also add ZeroStageEnum
# torch.serialization.safe_globals([LossScaler])
# torch.serialization.safe_globals([fragment_address])
# # Read the path arguments
# parser = argparse.ArgumentParser()
# parser.add_argument('--input', type=str, required=True, help='path to the desired checkpoint folder')
# parser.add_argument('--output', type=str, default=None, help='path to the pytorch fp32 state_dict output file')
# args = parser.parse_args()
# # Set the default output path
# if args.output is None:
# args.output = str(Path(args.input).parent / 'converted_fp32.ckpt')
# # Run the conversion
# try:
# convert_zero_checkpoint_to_fp32_state_dict(args.input, args.output)
# print(f"Successfully converted checkpoint to: {args.output}")
# except Exception as e:
# print(f"Conversion failed: {str(e)}")