#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Robust ZeRO->fp32 converter for Torch>=2.6 (weights_only=True default).

It (1) pre-allowlists common DeepSpeed symbols; (2) on failure, parses the
'Unsupported global: GLOBAL ...' from the exception, allowlists it, and
retries.

Also provides ConvertAfterSaveCallback for use in stage1.py / stage2.py to run
conversion automatically after each checkpoint save when using DeepSpeed.
"""

import argparse
import importlib
import os
import re
from pathlib import Path

__all__ = [
    "convert_zero_to_fp32",
    "ConvertAfterSaveCallback",
    "main",
    # helpers exported for unit testing
    "_extract_unsupported_globals",
    "_try_import_symbol",
]


def _has_add_safe_globals():
    """Return True if this torch build exposes torch.serialization.add_safe_globals."""
    try:
        from torch.serialization import add_safe_globals  # noqa: F401
        return True
    except Exception:
        return False


def _add_safe(objs):
    """Best-effort: register *objs* with torch's safe-globals allowlist.

    Silently no-ops when add_safe_globals is unavailable (old torch) or the
    registration itself fails — callers treat allowlisting as advisory.
    """
    try:
        from torch.serialization import add_safe_globals
        add_safe_globals(objs)
    except Exception:
        pass


def _try_import_symbol(qualname: str):
    """
    Import 'a.b.c' -> returns object 'c' from module 'a.b'.

    Returns None if anything fails (including a bare name with no dot,
    which makes rsplit unpack raise ValueError).
    """
    try:
        mod_name, attr = qualname.rsplit('.', 1)
        mod = importlib.import_module(mod_name)
        return getattr(mod, attr)
    except Exception:
        return None


def _pre_allowlist_commons():
    """Pre-allowlist DeepSpeed symbols commonly pickled inside ZeRO shards."""
    commons = [
        # FP16 scalers
        "deepspeed.runtime.fp16.loss_scaler.LossScaler",
        "deepspeed.runtime.fp16.dynamic_loss_scaler.DynamicLossScaler",
        # ZeRO enums/config/status
        "deepspeed.runtime.zero.config.ZeroStageEnum",
        "deepspeed.runtime.zero.stage_1_and_2.ZeroParamStatus",
        "deepspeed.runtime.zero.stage_1_and_2.ZeroOptimizerStage2",
        "deepspeed.runtime.config.DeepSpeedConfig",
        # You just hit this one:
        "deepspeed.utils.tensor_fragment.fragment_address",
    ]
    objs = [o for o in map(_try_import_symbol, commons) if o is not None]
    if objs:
        _add_safe(objs)


def _extract_unsupported_globals(msg: str):
    """
    Parse error text for lines like:
      'Unsupported global: GLOBAL deepspeed.utils.tensor_fragment.fragment_address'
    Return list of qualified names (unordered, de-duplicated).
    """
    pats = [
        r"Unsupported global:\s+GLOBAL\s+([A-Za-z0-9_\.]+)",
        # re.DOTALL: torch's message wraps across lines between the phrase
        # and the bracketed [name]; without it the lazy .*? never matches.
        r"was not an allowed global.*?\[\s*([A-Za-z0-9_\.]+)\s*\]",
    ]
    found = set()
    for pat in pats:
        for m in re.finditer(pat, msg, flags=re.DOTALL):
            found.add(m.group(1))
    return list(found)


def convert_zero_to_fp32(ckpt_dir: str, out_path: str, max_retries: int = 5):
    """
    Convert a DeepSpeed ZeRO checkpoint folder to a single fp32 state_dict file.

    On torch>=2.6 (weights_only=True default) unpickling the shards may reject
    DeepSpeed classes; each failure is parsed for the missing globals, those
    are allowlisted, and the conversion is retried up to *max_retries* times.

    Raises:
        The last conversion error if all retries are exhausted, or
        RuntimeError if max_retries < 1 made no attempt at all.
    """
    from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

    # Step 0: pre-allowlist common DS symbols (no-op on old torch)
    if _has_add_safe_globals():
        _pre_allowlist_commons()

    # Step 1: try convert; on failure, parse & allowlist missing globals, then retry
    last_err = None
    for attempt in range(1, max_retries + 1):
        try:
            convert_zero_checkpoint_to_fp32_state_dict(ckpt_dir, out_path)
            print(f"[OK] Converted ZeRO checkpoint → {out_path}")
            return
        except Exception as e:
            last_err = e
            missing = _extract_unsupported_globals(str(e)) if _has_add_safe_globals() else []
            if not missing:
                # nothing to auto-allowlist or on old torch -> just bail
                break
            objs = [o for o in map(_try_import_symbol, missing) if o is not None]
            if not objs:
                # couldn't import any of them
                break
            _add_safe(objs)
            print(f"[Retry {attempt}/{max_retries}] allowlisted: {', '.join(missing)}; retrying…")

    # If we reach here, conversion failed. Guard: with max_retries < 1 the
    # loop never ran and last_err is None ('raise None' would be a TypeError).
    if last_err is None:
        raise RuntimeError(f"convert_zero_to_fp32 made no attempts (max_retries={max_retries})")
    raise last_err


def _convert_after_save_callback_class(run_after_train_epoch):
    """Build a PL Callback class that runs convert after checkpoint save (DeepSpeed only, rank 0)."""
    import pytorch_lightning as pl

    class _ConvertAfterSaveCallback(pl.Callback):
        def __init__(self, dirpath, save_every_n_epochs):
            # Output directory for converted.ckpt (trailing separators stripped).
            self.dirpath = dirpath.rstrip(os.sep)
            self.save_every_n_epochs = save_every_n_epochs
            self._run_after_train = run_after_train_epoch

        def _maybe_convert(self, trainer):
            # Only rank 0 converts — other ranks would race on the output file.
            if getattr(trainer, 'global_rank', 0) != 0:
                return
            # Only meaningful for DeepSpeed strategies (ZeRO sharded checkpoints).
            strategy = getattr(trainer, 'strategy', None)
            if strategy is None or 'DeepSpeed' not in type(strategy).__name__:
                return
            # 1-based epoch count so "every N" lines up with ModelCheckpoint.
            epoch = trainer.current_epoch + 1
            if epoch % self.save_every_n_epochs != 0:
                return
            for cb in trainer.callbacks:
                if type(cb).__name__ == 'ModelCheckpoint':
                    # Prefer the most recent save; fall back to the best one.
                    last_path = getattr(cb, 'last_model_path', None) or getattr(cb, 'best_model_path', None)
                    if not last_path or not os.path.exists(last_path):
                        return
                    out_path = os.path.join(self.dirpath, 'converted.ckpt')
                    try:
                        convert_zero_to_fp32(last_path, out_path)
                    except Exception as e:
                        # Best-effort: a failed conversion must not kill training.
                        print(f"[ConvertAfterSave] Conversion failed: {e}")
                    return  # only the first ModelCheckpoint is considered

        def on_train_epoch_end(self, trainer, pl_module):
            if self._run_after_train:
                self._maybe_convert(trainer)

        def on_validation_epoch_end(self, trainer, pl_module):
            if not self._run_after_train:
                self._maybe_convert(trainer)

    return _ConvertAfterSaveCallback


def ConvertAfterSaveCallback(dirpath, save_every_n_epochs, run_after_train_epoch=True):
    """Callback instance: after each checkpoint save, run ZeRO->fp32 and write dirpath/converted.ckpt."""
    return _convert_after_save_callback_class(run_after_train_epoch)(dirpath, save_every_n_epochs)


def main():
    """CLI entry: convert --input ZeRO folder to --output fp32 file (default <input>/converted.ckpt)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, required=True,
                        help='Path to the ZeRO checkpoint folder (…/epoch=XX.ckpt/checkpoint)')
    parser.add_argument('--output', type=str, default=None,
                        help='Path to output fp32 PyTorch state_dict file')
    args = parser.parse_args()
    ckpt_dir = Path(args.input)
    out = Path(args.output) if args.output is not None else (ckpt_dir / 'converted.ckpt')
    convert_zero_to_fp32(str(ckpt_dir), str(out))


if __name__ == '__main__':
    main()