"""Consolidate a DeepSpeed ZeRO checkpoint folder into a single fp32 state_dict file.

Thin command-line wrapper around Lightning's
``convert_zero_checkpoint_to_fp32_state_dict``.

Usage::

    python <this_script>.py --input path/to/checkpoint_dir [--output out.ckpt] [--tag TAG]
"""
import argparse
from pathlib import Path

from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

# File name used when --output is omitted; it is written inside the --input folder.
DEFAULT_OUTPUT_NAME = 'converted.ckpt'


def main(argv=None):
    """Parse command-line arguments and convert the checkpoint.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``; ``None`` uses the
            real command line. Useful when invoking the conversion from other code.

    Returns:
        The ``Path`` the fp32 state_dict was written to.
    """
    parser = argparse.ArgumentParser(
        description='Convert a DeepSpeed ZeRO checkpoint folder into a single fp32 state_dict file.'
    )
    # required=True: with the previous default of None the script crashed later in
    # Path(None) (or inside the converter) with an unhelpful error.
    parser.add_argument('--input', type=str, required=True, help='path to the desired checkpoint folder')
    parser.add_argument('--output', type=str, default=None, help='path to the pytorch fp32 state_dict output file')
    # Optional; None leaves the choice of tag to the converter (same as not passing it).
    parser.add_argument('--tag', type=str, default=None, help='checkpoint tag used as a unique identifier for checkpoint')
    args = parser.parse_args(argv)

    input_dir = Path(args.input)
    # Fail fast with a usage error instead of a traceback from deep inside the converter.
    if not input_dir.is_dir():
        parser.error(f'--input must be an existing checkpoint folder, got: {args.input}')

    # Normalize to Path in both branches (previously str in one, Path in the other).
    output = Path(args.output) if args.output is not None else input_dir / DEFAULT_OUTPUT_NAME

    convert_zero_checkpoint_to_fp32_state_dict(input_dir, output, tag=args.tag)
    print(f'Successfully converted checkpoint to: {output}')
    return output


if __name__ == '__main__':
    main()
# Alternative variant, kept for reference and currently disabled. It registers DeepSpeed
# classes with torch's weights-only unpickler allowlist, presumably for PyTorch versions
# where torch.load defaults to weights_only=True and rejects these objects — TODO confirm.
# import argparse
# from pathlib import Path
# from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
# from torch.serialization import add_safe_globals
# from deepspeed.runtime.fp16.loss_scaler import LossScaler
# from deepspeed.runtime.zero.config import ZeroStageEnum
# from deepspeed.utils.tensor_fragment import fragment_address
# import torch
# if __name__ == '__main__':
#     # Add DeepSpeed's LossScaler to the safe-globals allowlist
#     add_safe_globals([LossScaler, ZeroStageEnum])  # also add ZeroStageEnum
#     # NOTE(review): torch.serialization.safe_globals is a context manager; the two bare
#     # calls below have no effect unless written as `with torch.serialization.safe_globals([...]):`
#     # (fragment_address is therefore never allowlisted here).
#     torch.serialization.safe_globals([LossScaler])
#     torch.serialization.safe_globals([fragment_address])
#     # Read the path arguments
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--input', type=str, required=True, help='path to the desired checkpoint folder')
#     parser.add_argument('--output', type=str, default=None, help='path to the pytorch fp32 state_dict output file')
#     args = parser.parse_args()
#     # Set the default output path
#     if args.output is None:
#         args.output = str(Path(args.input).parent / 'converted_fp32.ckpt')
#     # Run the conversion
#     try:
#         convert_zero_checkpoint_to_fp32_state_dict(args.input, args.output)
#         print(f"Successfully converted checkpoint to: {args.output}")
#     except Exception as e:
#         print(f"Conversion failed: {str(e)}")