Spaces:
Running on Zero
Running on Zero
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | |
| """ | |
| Utility helpers for mimo models. | |
| """ | |
| import torch | |
| from megatron.core import dist_checkpointing | |
def load_submodule_ckpt(module: torch.nn.Module, ckpt_dir: str, prefix: str = "module."):
    """Load *ckpt_dir* into *module* using Megatron distributed-checkpointing.

    Args:
        module: Module to populate. Besides the standard ``load_state_dict`` it
            must provide Megatron's ``sharded_state_dict(prefix=...)``.
        ckpt_dir: Directory containing the distributed checkpoint.
        prefix: Prefix under which the module's entries are keyed in the
            checkpoint. Defaults to ``"module."`` (the previous hard-coded
            value). Pass ``""`` for a checkpoint saved without a prefix;
            this must match how the checkpoint was written.

    Raises:
        RuntimeError: If the loaded state dict is missing keys, or has
            unexpected keys, other than ``extra_state`` entries.
    """
    # 1) Request the sharded state dict under *prefix* so its keys match the
    #    checkpoint's keys. fp8 ``extra_state`` entries are skipped because
    #    older checkpoints may not contain them.
    sharded_sd = {
        key: value
        for key, value in module.sharded_state_dict(prefix=prefix).items()
        if "extra_state" not in key
    }

    # 2) Wrap everything under one root key. The loaded result is read back
    #    under that same key, which keeps the prefixed entries grouped together.
    loaded = dist_checkpointing.load(
        sharded_state_dict={"state_dict": sharded_sd},
        checkpoint_dir=ckpt_dir,
    )

    # 3) Strip the prefix and push the tensors into the module.
    #    strict=False because the skipped extra_state entries will show up as
    #    missing; every other mismatch is still treated as an error below.
    cleaned = {key.removeprefix(prefix): value for key, value in loaded["state_dict"].items()}
    incompatible = module.load_state_dict(cleaned, strict=False)

    unexpected = [k for k in incompatible.unexpected_keys if "extra_state" not in k]
    missing = [k for k in incompatible.missing_keys if "extra_state" not in k]
    if unexpected or missing:
        raise RuntimeError(
            f"load_state_dict had unexpected mismatch. Missing: {missing}, Unexpected: {unexpected}"
        )