File size: 2,199 Bytes
48cce71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import argparse
from pathlib import Path
from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

def parse_args(argv=None):
    """Parse command-line options for the ZeRO -> fp32 checkpoint conversion.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``input``, ``output`` and ``tag``.
        When ``--output`` is omitted, ``output`` defaults to
        ``<input>/converted.ckpt`` (always a ``str``).

    Raises:
        SystemExit: If ``--input`` is missing or the arguments are invalid.
    """
    parser = argparse.ArgumentParser(
        description='Convert a DeepSpeed ZeRO checkpoint folder into a single fp32 state_dict file.'
    )
    # --input is required: without it the default-output logic below would
    # call Path(None) and fail with an unhelpful TypeError.
    parser.add_argument('--input', type=str, required=True, help='path to the desired checkpoint folder')
    parser.add_argument('--output', type=str, default=None, help='path to the pytorch fp32 state_dict output file')
    # Optional; None is forwarded unchanged, matching the previous behavior.
    parser.add_argument('--tag', type=str, default=None, help='checkpoint tag used as a unique identifier for checkpoint')
    args = parser.parse_args(argv)
    if args.output is None:
        # Keep the type consistent with an explicitly passed --output (str).
        args.output = str(Path(args.input) / 'converted.ckpt')
    return args


def main(argv=None):
    """Read paths from the command line and run the ZeRO -> fp32 conversion.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_args(argv)
    convert_zero_checkpoint_to_fp32_state_dict(args.input, args.output, tag=args.tag)


if __name__ == '__main__':
    main()

# import argparse
# from pathlib import Path
# from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
# from torch.serialization import add_safe_globals
# from deepspeed.runtime.fp16.loss_scaler import LossScaler
# from deepspeed.runtime.zero.config import ZeroStageEnum
# from deepspeed.utils.tensor_fragment import fragment_address
# import torch

# if __name__ == '__main__':
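#     # Add DeepSpeed's LossScaler to torch's safe-globals allowlist
#     add_safe_globals([LossScaler, ZeroStageEnum])  # also allowlist ZeroStageEnum

#     # NOTE: torch.serialization.safe_globals is a context manager; calling it
#     # without `with` has no effect. Use add_safe_globals (above) instead.
#     torch.serialization.safe_globals([LossScaler])
#     torch.serialization.safe_globals([fragment_address])
#     # Read the path arguments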
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--input', type=str, required=True, help='path to the desired checkpoint folder')
#     parser.add_argument('--output', type=str, default=None, help='path to the pytorch fp32 state_dict output file')
#     args = parser.parse_args()

#     # Set the default output path
#     if args.output is None:
#         args.output = str(Path(args.input).parent / 'converted_fp32.ckpt')

#     # Run the conversion
#     try:
#         convert_zero_checkpoint_to_fp32_state_dict(args.input, args.output)
#         print(f"Successfully converted checkpoint to: {args.output}")
#     except Exception as e:
#         print(f"Conversion failed: {str(e)}")