File size: 7,476 Bytes
7c15d15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Robust ZeRO->fp32 converter for Torch>=2.6 (weights_only=True default).
It (1) pre-allowlists common DeepSpeed symbols; (2) on failure, parses the
'Unsupported global: GLOBAL ...' from the exception, allowlists it, and retries.

Also provides ConvertAfterSaveCallback for use in stage1.py / stage2.py to
run conversion automatically after each checkpoint save when using DeepSpeed.
"""

import argparse
import os
import re
import importlib
from pathlib import Path

def _has_add_safe_globals():
    try:
        from torch.serialization import add_safe_globals  # noqa: F401
        return True
    except Exception:
        return False

def _add_safe(objs):
    try:
        from torch.serialization import add_safe_globals
        add_safe_globals(objs)
    except Exception:
        pass

def _try_import_symbol(qualname: str):
    """
    Import 'a.b.c' -> returns object 'c' from module 'a.b'.
    Returns None if anything fails.
    """
    try:
        mod_name, attr = qualname.rsplit('.', 1)
        mod = importlib.import_module(mod_name)
        return getattr(mod, attr)
    except Exception:
        return None

def _pre_allowlist_commons():
    """Allowlist a fixed set of DeepSpeed symbols before the first load attempt.

    Each dotted name is resolved with ``_try_import_symbol``; names that do
    not resolve are skipped. Whatever did resolve is passed to ``_add_safe``
    in a single call.
    """
    candidate_names = (
        # FP16 scalers
        "deepspeed.runtime.fp16.loss_scaler.LossScaler",
        "deepspeed.runtime.fp16.dynamic_loss_scaler.DynamicLossScaler",
        # ZeRO enums/config/status
        "deepspeed.runtime.zero.config.ZeroStageEnum",
        "deepspeed.runtime.zero.stage_1_and_2.ZeroParamStatus",
        "deepspeed.runtime.zero.stage_1_and_2.ZeroOptimizerStage2",
        "deepspeed.runtime.config.DeepSpeedConfig",
        # Tensor-fragment address record
        "deepspeed.utils.tensor_fragment.fragment_address",
    )
    resolved = [
        symbol
        for symbol in map(_try_import_symbol, candidate_names)
        if symbol is not None
    ]
    if resolved:
        _add_safe(resolved)

def _extract_unsupported_globals(msg: str):
    """
    Parse error text for lines like:
    'Unsupported global: GLOBAL deepspeed.utils.tensor_fragment.fragment_address'
    Return list of qualified names.
    """
    pats = [
        r"Unsupported global:\s+GLOBAL\s+([A-Za-z0-9_\.]+)",
        r"was not an allowed global.*?\[\s*([A-Za-z0-9_\.]+)\s*\]",
    ]
    found = set()
    for pat in pats:
        for m in re.finditer(pat, msg):
            found.add(m.group(1))
    return list(found)

def convert_zero_to_fp32(ckpt_dir: str, out_path: str, max_retries: int = 5):
    """Convert a ZeRO checkpoint to an fp32 state dict, allowlisting globals as needed.

    The conversion itself is delegated to Lightning's
    ``convert_zero_checkpoint_to_fp32_state_dict``. When it fails and the
    error text names globals that were rejected, those globals are imported,
    allowlisted, and the conversion is retried.

    Args:
        ckpt_dir: Path passed through as the checkpoint location.
        out_path: Path passed through as the output file.
        max_retries: Maximum number of conversion attempts; must be >= 1.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        Exception: The last conversion error, when no attempt succeeded.
    """
    # Validate first: with zero attempts there would be no error to re-raise.
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries!r}")

    from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

    # Step 0: pre-allowlist common DS symbols (no-op on old torch)
    can_allowlist = _has_add_safe_globals()
    if can_allowlist:
        _pre_allowlist_commons()

    # Step 1: try convert; on failure, parse & allowlist missing globals, then retry
    last_err = None
    attempted = set()  # names already handled, so a futile retry is never repeated
    for attempt in range(1, max_retries + 1):
        try:
            convert_zero_checkpoint_to_fp32_state_dict(ckpt_dir, out_path)
        except Exception as e:
            last_err = e
        else:
            print(f"[OK] Converted ZeRO checkpoint → {out_path}")
            return

        if not can_allowlist:
            # old torch: nothing can be allowlisted -> just bail
            break

        fresh = [
            qn for qn in _extract_unsupported_globals(str(last_err))
            if qn not in attempted
        ]
        attempted.update(fresh)

        names, objs = [], []
        for qn in fresh:
            obj = _try_import_symbol(qn)
            if obj is not None:
                names.append(qn)
                objs.append(obj)
        if not objs:
            # nothing new could be imported; retrying would fail the same way
            break

        _add_safe(objs)
        if attempt < max_retries:
            # Report only the names that were actually resolved and registered.
            print(f"[Retry {attempt}/{max_retries}] allowlisted: {', '.join(names)}; retrying…")

    # If we reach here, conversion failed
    raise last_err


def _convert_after_save_callback_class(run_after_train_epoch):
    """Build a PL Callback class that runs convert after checkpoint save (DeepSpeed only, rank 0).

    Args:
        run_after_train_epoch: When truthy, conversion is attempted in
            ``on_train_epoch_end``; otherwise in ``on_validation_epoch_end``.

    Returns:
        The callback class (not an instance). Its constructor takes
        ``(dirpath, save_every_n_epochs)``.
    """
    import pytorch_lightning as pl

    class _ConvertAfterSaveCallback(pl.Callback):
        def __init__(self, dirpath, save_every_n_epochs):
            super().__init__()
            # os.fspath lets callers pass a pathlib.Path as well as a str.
            self.dirpath = os.fspath(dirpath).rstrip(os.sep)
            self.save_every_n_epochs = save_every_n_epochs
            self._run_after_train = run_after_train_epoch

        def _maybe_convert(self, trainer):
            """Convert the most recent checkpoint if all preconditions hold."""
            # Only rank 0 writes the converted file.
            if getattr(trainer, 'global_rank', 0) != 0:
                return
            strategy = getattr(trainer, 'strategy', None)
            if strategy is None or 'DeepSpeed' not in type(strategy).__name__:
                return
            interval = self.save_every_n_epochs
            if not interval:
                # 0/None would make the modulo below raise inside the training
                # loop; treat it as "never convert" instead.
                return
            epoch = trainer.current_epoch + 1
            if epoch % interval != 0:
                return
            for cb in trainer.callbacks:
                if type(cb).__name__ == 'ModelCheckpoint':
                    last_path = getattr(cb, 'last_model_path', None) or getattr(cb, 'best_model_path', None)
                    if not last_path or not os.path.exists(last_path):
                        return
                    out_path = os.path.join(self.dirpath, 'converted.ckpt')
                    try:
                        convert_zero_to_fp32(last_path, out_path)
                    except Exception as e:
                        # A failed conversion is reported but must not abort training.
                        print(f"[ConvertAfterSave] Conversion failed: {e}")
                    # Only the first ModelCheckpoint callback is considered.
                    return

        def on_train_epoch_end(self, trainer, pl_module):
            if self._run_after_train:
                self._maybe_convert(trainer)

        def on_validation_epoch_end(self, trainer, pl_module):
            if not self._run_after_train:
                self._maybe_convert(trainer)

    return _ConvertAfterSaveCallback


def ConvertAfterSaveCallback(dirpath, save_every_n_epochs, run_after_train_epoch=True):
    """Create a callback instance that runs ZeRO->fp32 conversion into dirpath/converted.ckpt.

    Args:
        dirpath: Directory that receives ``converted.ckpt``.
        save_every_n_epochs: Epoch interval at which conversion is attempted.
        run_after_train_epoch: Hook selector forwarded to the class factory.

    Returns:
        An instance of the class built by ``_convert_after_save_callback_class``.
    """
    callback_cls = _convert_after_save_callback_class(run_after_train_epoch)
    instance = callback_cls(dirpath, save_every_n_epochs)
    return instance


def main():
    """Command-line entry point: parse --input/--output and run the conversion.

    When --output is omitted, the result is written to ``converted.ckpt``
    inside the input folder.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--input',
        type=str,
        required=True,
        help='Path to the ZeRO checkpoint folder (…/epoch=XX.ckpt/checkpoint)',
    )
    cli.add_argument(
        '--output',
        type=str,
        default=None,
        help='Path to output fp32 PyTorch state_dict file',
    )
    opts = cli.parse_args()

    source_dir = Path(opts.input)
    if opts.output is None:
        target = source_dir / 'converted.ckpt'
    else:
        target = Path(opts.output)

    convert_zero_to_fp32(str(source_dir), str(target))


if __name__ == '__main__':
    main()



# import argparse
# from pathlib import Path
# from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

# if __name__ == '__main__':
#     ## read a path using argparse and pass it to convert_zero_checkpoint_to_fp32_state_dict
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--input', type=str, default=None, help='path to the desired checkpoint folder')
#     parser.add_argument('--output', type=str, default=None, help='path to the pytorch fp32 state_dict output file')
#     # parser.add_argument('--tag', type=str, help='checkpoint tag used as a unique identifier for checkpoint')
#     args = parser.parse_args()
#     if args.output is None:
#         args.output = Path(args.input) / 'converted.ckpt'
#     convert_zero_checkpoint_to_fp32_state_dict(args.input, args.output)